fix(billing): reject rate_multiplier <= 0 on save; clamp negatives to 0 in compute
分组倍率和用户专属倍率在保存时没有校验,0 会触发计费层的 `<=0 → 1.0` 防御条款,结果订阅/余额分组按标准价扣费;完全是沉默地绕过了业务规则。 - 保存校验(admin_service):CreateGroup / UpdateGroup / BatchSetGroupRateMultipliers / UpdateUser.SyncUserGroupRates 全部要求 > 0 - 计算层(billing_service):三处 `<=0 → 1.0` 改为 `<0 → 0`;负数按 0 结算, 避免配置异常被静默按 1x 收费 - 前端:分组倍率 / 用户专属倍率输入 min 统一到 0.001 - 删除未使用的 IsFreeSubscription 方法 测试:新增 billing_service_rate_multiplier_test.go 端到端验证;更新原有锁定 旧 `<=0 → 1.0` 行为的测试。
This commit is contained in:
@@ -586,6 +586,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
|
||||
// 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率)
|
||||
if input.GroupRates != nil {
|
||||
for groupID, rate := range input.GroupRates {
|
||||
if rate != nil && *rate <= 0 {
|
||||
return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -811,6 +820,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
|
||||
if input.RateMultiplier <= 0 {
|
||||
return nil, errors.New("rate_multiplier must be > 0")
|
||||
}
|
||||
|
||||
platform := input.Platform
|
||||
if platform == "" {
|
||||
platform = PlatformAnthropic
|
||||
@@ -1050,6 +1063,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.Platform = input.Platform
|
||||
}
|
||||
if input.RateMultiplier != nil {
|
||||
if *input.RateMultiplier <= 0 {
|
||||
return nil, errors.New("rate_multiplier must be > 0")
|
||||
}
|
||||
group.RateMultiplier = *input.RateMultiplier
|
||||
}
|
||||
if input.IsExclusive != nil {
|
||||
@@ -1286,6 +1302,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.RateMultiplier <= 0 {
|
||||
return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID)
|
||||
}
|
||||
}
|
||||
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
||||
}
|
||||
|
||||
|
||||
@@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformOpenAI,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
@@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &zero,
|
||||
})
|
||||
|
||||
@@ -448,8 +448,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
|
||||
})
|
||||
}
|
||||
|
||||
if input.RateMultiplier <= 0 {
|
||||
input.RateMultiplier = 1.0
|
||||
// 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。
|
||||
if input.RateMultiplier < 0 {
|
||||
input.RateMultiplier = 0
|
||||
}
|
||||
|
||||
var breakdown *CostBreakdown
|
||||
@@ -493,8 +494,9 @@ func (s *BillingService) computeTokenBreakdown(
|
||||
rateMultiplier float64, serviceTier string,
|
||||
applyLongCtx bool,
|
||||
) *CostBreakdown {
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
// 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。
|
||||
if rateMultiplier < 0 {
|
||||
rateMultiplier = 0
|
||||
}
|
||||
|
||||
inputPrice := pricing.InputPricePerToken
|
||||
@@ -831,9 +833,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
|
||||
// 计算总费用
|
||||
totalCost := unitPrice * float64(imageCount)
|
||||
|
||||
// 应用倍率
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
// 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣)
|
||||
if rateMultiplier < 0 {
|
||||
rateMultiplier = 0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
|
||||
@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) {
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
|
||||
// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费
|
||||
// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。
|
||||
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
|
||||
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被
|
||||
// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。
|
||||
func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
multiplier float64
|
||||
wantRatio float64 // ActualCost / TotalCost
|
||||
}{
|
||||
{"negative clamped to 0", -1.5, 0},
|
||||
{"zero passes through as 0 (defense in depth)", 0, 0},
|
||||
{"positive 2x applied", 2.0, 2.0},
|
||||
{"positive 0.5x applied", 0.5, 0.5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero")
|
||||
require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径
|
||||
// 同样遵循"负数 → 0"语义。
|
||||
func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
price := 0.04
|
||||
cfg := &ImagePriceConfig{Price1K: &price}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
multiplier float64
|
||||
wantRatio float64
|
||||
}{
|
||||
{"negative clamped to 0", -0.5, 0},
|
||||
{"zero passes through", 0, 0},
|
||||
{"positive 3x applied", 3.0, 3.0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier)
|
||||
require.NotNil(t, cost)
|
||||
require.Greater(t, cost.TotalCost, 0.0)
|
||||
require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) {
|
||||
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
|
||||
@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) {
|
||||
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
|
||||
// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
|
||||
// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
|
||||
func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(nil, bs)
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
|
||||
costZero, err := bs.CalculateCostUnified(CostInput{
|
||||
cost, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 0, // should default to 1.0
|
||||
RateMultiplier: 0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
|
||||
require.Greater(t, cost.TotalCost, 0.0)
|
||||
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
|
||||
// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
|
||||
// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
|
||||
func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
|
||||
bs := newTestBillingService()
|
||||
resolver := NewModelPricingResolver(nil, bs)
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costNeg, err := bs.CalculateCostUnified(CostInput{
|
||||
cost, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T)
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := bs.CalculateCostUnified(CostInput{
|
||||
Ctx: context.Background(),
|
||||
Model: "claude-sonnet-4",
|
||||
Tokens: tokens,
|
||||
RateMultiplier: 1.0,
|
||||
Resolver: resolver,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
|
||||
require.Greater(t, cost.TotalCost, 0.0)
|
||||
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
|
||||
|
||||
@@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool {
|
||||
return g.SubscriptionType == SubscriptionTypeSubscription
|
||||
}
|
||||
|
||||
func (g *Group) IsFreeSubscription() bool {
|
||||
return g.IsSubscriptionType() && g.RateMultiplier == 0
|
||||
}
|
||||
|
||||
func (g *Group) HasDailyLimit() bool {
|
||||
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
|
||||
}
|
||||
|
||||
@@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
|
||||
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}},
|
||||
User: &User{ID: 200},
|
||||
Account: &Account{ID: 300},
|
||||
Subscription: subscription,
|
||||
|
||||
Reference in New Issue
Block a user