refactor: replace magic strings with named constants
- PricingSourceChannel/LiteLLM/Fallback for resolver source - MediaTypeImage/Video/Prompt for result.MediaType - Reuse BillingModeToken/BillingModeImage for billing mode - Reuse BillingModelSourceChannelMapped/PlatformAnthropic in handler
This commit is contained in:
@@ -130,7 +130,7 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
}
|
||||
resp.BillingModelSource = ch.BillingModelSource
|
||||
if resp.BillingModelSource == "" {
|
||||
resp.BillingModelSource = "channel_mapped"
|
||||
resp.BillingModelSource = service.BillingModelSourceChannelMapped
|
||||
}
|
||||
if resp.GroupIDs == nil {
|
||||
resp.GroupIDs = []int64{}
|
||||
@@ -147,11 +147,11 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
||||
}
|
||||
billingMode := string(p.BillingMode)
|
||||
if billingMode == "" {
|
||||
billingMode = "token"
|
||||
billingMode = string(service.BillingModeToken)
|
||||
}
|
||||
platform := p.Platform
|
||||
if platform == "" {
|
||||
platform = "anthropic"
|
||||
platform = service.PlatformAnthropic
|
||||
}
|
||||
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
|
||||
for _, iv := range p.Intervals {
|
||||
@@ -194,7 +194,7 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
||||
}
|
||||
platform := r.Platform
|
||||
if platform == "" {
|
||||
platform = "anthropic"
|
||||
platform = service.PlatformAnthropic
|
||||
}
|
||||
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
||||
for _, iv := range r.Intervals {
|
||||
|
||||
@@ -60,6 +60,19 @@ const (
|
||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||
)
|
||||
|
||||
// MediaType 媒体类型常量
|
||||
const (
|
||||
MediaTypeImage = "image"
|
||||
MediaTypeVideo = "video"
|
||||
MediaTypePrompt = "prompt"
|
||||
)
|
||||
|
||||
const (
|
||||
claudeMaxMessageOverheadTokens = 3
|
||||
claudeMaxBlockOverheadTokens = 1
|
||||
claudeMaxUnknownContentTokens = 4
|
||||
)
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
@@ -7744,7 +7757,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.MediaType == "image" || result.MediaType == "video" {
|
||||
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
|
||||
var soraConfig *SoraPriceConfig
|
||||
if apiKey.Group != nil {
|
||||
soraConfig = &SoraPriceConfig{
|
||||
@@ -7754,12 +7767,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
}
|
||||
}
|
||||
if result.MediaType == "image" {
|
||||
if result.MediaType == MediaTypeImage {
|
||||
cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
|
||||
} else {
|
||||
cost = s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
|
||||
}
|
||||
} else if result.MediaType == "prompt" {
|
||||
} else if result.MediaType == MediaTypePrompt {
|
||||
cost = &CostBreakdown{}
|
||||
} else if result.ImageCount > 0 {
|
||||
// 图片生成计费:渠道级别定价优先,否则走按次计费(兼容旧版本)
|
||||
@@ -7767,7 +7780,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||
if resolved.Source == "channel" {
|
||||
if resolved.Source == PricingSourceChannel {
|
||||
hasChannelPricing = true
|
||||
}
|
||||
}
|
||||
@@ -7900,15 +7913,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
|
||||
// 设置计费模式
|
||||
if result.MediaType != "image" && result.MediaType != "video" && result.MediaType != "prompt" {
|
||||
if result.MediaType != MediaTypeImage && result.MediaType != MediaTypeVideo && result.MediaType != MediaTypePrompt {
|
||||
if cost != nil && cost.BillingMode != "" {
|
||||
billingMode := cost.BillingMode
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else if result.ImageCount > 0 {
|
||||
billingMode := "image"
|
||||
billingMode := string(BillingModeImage)
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else {
|
||||
billingMode := "token"
|
||||
billingMode := string(BillingModeToken)
|
||||
usageLog.BillingMode = &billingMode
|
||||
}
|
||||
}
|
||||
@@ -8038,7 +8051,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
if s.resolver != nil && apiKey.Group != nil {
|
||||
gid := apiKey.Group.ID
|
||||
resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
|
||||
if resolved.Source == "channel" {
|
||||
if resolved.Source == PricingSourceChannel {
|
||||
hasChannelPricing = true
|
||||
}
|
||||
}
|
||||
@@ -8094,7 +8107,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
Model: billingModel,
|
||||
GroupID: &gid,
|
||||
})
|
||||
if resolved.Source == "channel" {
|
||||
if resolved.Source == PricingSourceChannel {
|
||||
// 有渠道定价,渠道区间已包含上下文分层
|
||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||
Ctx: ctx,
|
||||
@@ -8179,10 +8192,10 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
billingMode := cost.BillingMode
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else if result.ImageCount > 0 {
|
||||
billingMode := "image"
|
||||
billingMode := string(BillingModeImage)
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else {
|
||||
billingMode := "token"
|
||||
billingMode := string(BillingModeToken)
|
||||
usageLog.BillingMode = &billingMode
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,13 @@ import (
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
// PricingSource 定价来源标识
|
||||
const (
|
||||
PricingSourceChannel = "channel"
|
||||
PricingSourceLiteLLM = "litellm"
|
||||
PricingSourceFallback = "fallback"
|
||||
)
|
||||
|
||||
// ResolvedPricing 统一定价解析结果
|
||||
type ResolvedPricing struct {
|
||||
// Mode 计费模式
|
||||
@@ -78,9 +85,9 @@ func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing,
|
||||
if err != nil {
|
||||
slog.Debug("failed to get model pricing from LiteLLM, using fallback",
|
||||
"model", model, "error", err)
|
||||
return nil, "fallback"
|
||||
return nil, PricingSourceFallback
|
||||
}
|
||||
return pricing, "litellm"
|
||||
return pricing, PricingSourceLiteLLM
|
||||
}
|
||||
|
||||
// applyChannelOverrides 应用渠道定价覆盖
|
||||
@@ -90,7 +97,7 @@ func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupI
|
||||
return
|
||||
}
|
||||
|
||||
resolved.Source = "channel"
|
||||
resolved.Source = PricingSourceChannel
|
||||
resolved.Mode = chPricing.BillingMode
|
||||
if resolved.Mode == "" {
|
||||
resolved.Mode = BillingModeToken
|
||||
|
||||
@@ -4290,7 +4290,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
billingMode := cost.BillingMode
|
||||
usageLog.BillingMode = &billingMode
|
||||
} else {
|
||||
billingMode := "token"
|
||||
billingMode := string(BillingModeToken)
|
||||
usageLog.BillingMode = &billingMode
|
||||
}
|
||||
// 添加 UserAgent
|
||||
|
||||
Reference in New Issue
Block a user