diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 7c26a47c..701f3659 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -586,6 +586,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI } func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { + // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率) + if input.GroupRates != nil { + for groupID, rate := range input.GroupRates { + if rate != nil && *rate <= 0 { + return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID) + } + } + } + user, err := s.userRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -811,6 +820,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro } func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { + if input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } + platform := input.Platform if platform == "" { platform = PlatformAnthropic @@ -1050,6 +1063,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.Platform = input.Platform } if input.RateMultiplier != nil { + if *input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } group.RateMultiplier = *input.RateMultiplier } if input.IsExclusive != nil { @@ -1286,6 +1302,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro if s.userGroupRateRepo == nil { return nil } + for _, e := range entries { + if e.RateMultiplier <= 0 { + return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID) + } + } return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index a4c6d0ca..41d2c26a 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformOpenAI, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeSubscription, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t * _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAntigravity, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing. group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &zero, }) diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 32a54cbe..c9f32b3b 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -448,8 +448,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, }) } - if input.RateMultiplier <= 0 { - input.RateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。 + if input.RateMultiplier < 0 { + input.RateMultiplier = 0 } var breakdown *CostBreakdown @@ -493,8 +494,9 @@ func (s *BillingService) computeTokenBreakdown( rateMultiplier float64, serviceTier string, applyLongCtx bool, ) *CostBreakdown { - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。 + if rateMultiplier < 0 { + rateMultiplier = 0 } inputPrice := pricing.InputPricePerToken @@ -831,9 +833,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag // 计算总费用 totalCost := unitPrice * float64(imageCount) - // 应用倍率 - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣) + if rateMultiplier < 0 { + rateMultiplier = 0 } actualCost := totalCost * rateMultiplier diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index fa90f6bb..8d3ca987 100644 --- a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) { require.Equal(t, 0.0, cost.ActualCost) } -// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0 +// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费 +// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。 func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { svc := &BillingService{} cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) require.InDelta(t, 0.201, cost.TotalCost, 0.0001) - require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 diff --git a/backend/internal/service/billing_service_rate_multiplier_test.go b/backend/internal/service/billing_service_rate_multiplier_test.go new file mode 100644 index 00000000..83788196 --- /dev/null +++ b/backend/internal/service/billing_service_rate_multiplier_test.go @@ -0,0 +1,63 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被 +// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。 +func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 // ActualCost / TotalCost + }{ + {"negative clamped to 0", -1.5, 0}, + {"zero passes through as 0 (defense in depth)", 0, 0}, + {"positive 2x applied", 2.0, 2.0}, + {"positive 0.5x applied", 0.5, 0.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier) + require.NoError(t, err) + require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero") + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} + +// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径 +// 同样遵循"负数 → 0"语义。 +func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + price := 0.04 + cfg := &ImagePriceConfig{Price1K: &price} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 + }{ + {"negative clamped to 0", -0.5, 0}, + {"zero passes through", 0, 0}, + {"positive 3x applied", 3.0, 3.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier) + require.NotNil(t, cost) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 2cf134e2..fc8361c7 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) { require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) } -func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) -} - -func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) -} - func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { svc := newTestBillingService() diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go index 694c3384..e6a92d1a 100644 --- a/backend/internal/service/billing_service_unified_test.go +++ b/backend/internal/service/billing_service_unified_test.go @@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) { require.Equal(t, string(BillingModeImage), cost.BillingMode) } -func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为: +// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。 +func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} - costZero, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, - RateMultiplier: 0, // should default to 1.0 + RateMultiplier: 0, Resolver: resolver, }) require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 1.0, - Resolver: resolver, - }) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } -func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为: +// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。 +func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000} - costNeg, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, @@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) Resolver: resolver, }) require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 1.0, - Resolver: resolver, - }) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 12262613..64434ae1 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool { return g.SubscriptionType == SubscriptionTypeSubscription } -func (g *Group) IsFreeSubscription() bool { - return g.IsSubscriptionType() && g.RateMultiplier == 0 -} - func (g *Group) HasDailyLimit() bool { return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0 } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index e6fa94aa..6fa8a5bd 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel Model: "gpt-5.1", Duration: time.Second, }, - APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, + APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}}, User: &User{ID: 200}, Account: &Account{ID: 300}, Subscription: subscription, diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue index bf79bea2..41b2e63c 100644 --- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue +++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue @@ -166,7 +166,7 @@