分组倍率和用户专属倍率在保存时没有校验,0 会触发计费层的 `<=0 → 1.0` 防御条款,结果订阅/余额分组按标准价扣费;等于静默地绕过了业务规则。

- 保存校验(admin_service):CreateGroup / UpdateGroup / BatchSetGroupRateMultipliers / UpdateUser.SyncUserGroupRates 全部要求 > 0
- 计算层(billing_service):三处 `<=0 → 1.0` 改为 `<0 → 0`;负数按 0 结算,避免配置异常被静默按 1x 收费
- 前端:分组倍率 / 用户专属倍率输入 min 统一到 0.001
- 删除未使用的 IsFreeSubscription 方法

测试:新增 billing_service_rate_multiplier_test.go 端到端验证;更新原有锁定旧 `<=0 → 1.0` 行为的测试。
245 lines
7.8 KiB
Go
245 lines
7.8 KiB
Go
//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// CalculateCostUnified
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) {
|
||
svc := newTestBillingService()
|
||
|
||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||
input := CostInput{
|
||
Model: "claude-sonnet-4",
|
||
Tokens: tokens,
|
||
RateMultiplier: 1.0,
|
||
Resolver: nil, // no resolver
|
||
}
|
||
cost, err := svc.CalculateCostUnified(input)
|
||
require.NoError(t, err)
|
||
|
||
// Should match the old-path result exactly
|
||
expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil)
|
||
require.NoError(t, err)
|
||
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10)
|
||
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
|
||
// BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil)
|
||
require.Empty(t, cost.BillingMode)
|
||
}
|
||
|
||
func TestCalculateCostUnified_TokenMode(t *testing.T) {
|
||
bs := newTestBillingService()
|
||
resolver := NewModelPricingResolver(nil, bs)
|
||
|
||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||
input := CostInput{
|
||
Ctx: context.Background(),
|
||
Model: "claude-sonnet-4",
|
||
Tokens: tokens,
|
||
RateMultiplier: 1.5,
|
||
Resolver: resolver,
|
||
}
|
||
cost, err := bs.CalculateCostUnified(input)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, cost)
|
||
|
||
// Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075
|
||
expectedTotal := 1000*3e-6 + 500*15e-6
|
||
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
|
||
require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10)
|
||
require.Equal(t, string(BillingModeToken), cost.BillingMode)
|
||
}
|
||
|
||
func TestCalculateCostUnified_PerRequestMode(t *testing.T) {
|
||
// Set up a ChannelService with a per-request pricing channel
|
||
cs := newTestChannelServiceWithCache(t, &channelCache{
|
||
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
|
||
{groupID: 1, model: "claude-sonnet-4"}: {
|
||
BillingMode: BillingModePerRequest,
|
||
PerRequestPrice: testPtrFloat64(0.05),
|
||
},
|
||
},
|
||
channelByGroupID: map[int64]*Channel{
|
||
1: {ID: 1, Status: StatusActive},
|
||
},
|
||
groupPlatform: map[int64]string{1: ""},
|
||
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
|
||
mappingByGroupModel: map[channelModelKey]string{},
|
||
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
|
||
byID: map[int64]*Channel{},
|
||
})
|
||
|
||
bs := newTestBillingService()
|
||
resolver := NewModelPricingResolver(cs, bs)
|
||
groupID := int64(1)
|
||
|
||
input := CostInput{
|
||
Ctx: context.Background(),
|
||
Model: "claude-sonnet-4",
|
||
GroupID: &groupID,
|
||
Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50},
|
||
RequestCount: 3,
|
||
RateMultiplier: 2.0,
|
||
Resolver: resolver,
|
||
}
|
||
cost, err := bs.CalculateCostUnified(input)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, cost)
|
||
|
||
// 3 requests * $0.05 = $0.15
|
||
require.InDelta(t, 0.15, cost.TotalCost, 1e-10)
|
||
// ActualCost = 0.15 * 2.0 = 0.30
|
||
require.InDelta(t, 0.30, cost.ActualCost, 1e-10)
|
||
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
|
||
}
|
||
|
||
func TestCalculateCostUnified_ImageMode(t *testing.T) {
|
||
cs := newTestChannelServiceWithCache(t, &channelCache{
|
||
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
|
||
{groupID: 2, model: "gemini-image"}: {
|
||
BillingMode: BillingModeImage,
|
||
PerRequestPrice: testPtrFloat64(0.10),
|
||
},
|
||
},
|
||
channelByGroupID: map[int64]*Channel{
|
||
2: {ID: 2, Status: StatusActive},
|
||
},
|
||
groupPlatform: map[int64]string{2: ""},
|
||
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
|
||
mappingByGroupModel: map[channelModelKey]string{},
|
||
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
|
||
byID: map[int64]*Channel{},
|
||
})
|
||
|
||
bs := &BillingService{
|
||
cfg: &config.Config{},
|
||
fallbackPrices: map[string]*ModelPricing{},
|
||
}
|
||
resolver := NewModelPricingResolver(cs, bs)
|
||
groupID := int64(2)
|
||
|
||
input := CostInput{
|
||
Ctx: context.Background(),
|
||
Model: "gemini-image",
|
||
GroupID: &groupID,
|
||
Tokens: UsageTokens{},
|
||
RequestCount: 2,
|
||
RateMultiplier: 1.0,
|
||
Resolver: resolver,
|
||
}
|
||
cost, err := bs.CalculateCostUnified(input)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, cost)
|
||
|
||
// 2 * $0.10 = $0.20
|
||
require.InDelta(t, 0.20, cost.TotalCost, 1e-10)
|
||
require.InDelta(t, 0.20, cost.ActualCost, 1e-10)
|
||
require.Equal(t, string(BillingModeImage), cost.BillingMode)
|
||
}
|
||
|
||
// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
|
||
// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
|
||
func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
|
||
bs := newTestBillingService()
|
||
resolver := NewModelPricingResolver(nil, bs)
|
||
|
||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||
|
||
cost, err := bs.CalculateCostUnified(CostInput{
|
||
Ctx: context.Background(),
|
||
Model: "claude-sonnet-4",
|
||
Tokens: tokens,
|
||
RateMultiplier: 0,
|
||
Resolver: resolver,
|
||
})
|
||
require.NoError(t, err)
|
||
require.Greater(t, cost.TotalCost, 0.0)
|
||
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
||
}
|
||
|
||
// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
|
||
// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
|
||
func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
|
||
bs := newTestBillingService()
|
||
resolver := NewModelPricingResolver(nil, bs)
|
||
|
||
tokens := UsageTokens{InputTokens: 1000}
|
||
|
||
cost, err := bs.CalculateCostUnified(CostInput{
|
||
Ctx: context.Background(),
|
||
Model: "claude-sonnet-4",
|
||
Tokens: tokens,
|
||
RateMultiplier: -5.0,
|
||
Resolver: resolver,
|
||
})
|
||
require.NoError(t, err)
|
||
require.Greater(t, cost.TotalCost, 0.0)
|
||
require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
|
||
}
|
||
|
||
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
|
||
bs := newTestBillingService()
|
||
resolver := NewModelPricingResolver(nil, bs)
|
||
|
||
cost, err := bs.CalculateCostUnified(CostInput{
|
||
Ctx: context.Background(),
|
||
Model: "claude-sonnet-4",
|
||
Tokens: UsageTokens{InputTokens: 100},
|
||
RateMultiplier: 1.0,
|
||
Resolver: resolver,
|
||
})
|
||
require.NoError(t, err)
|
||
require.Equal(t, "token", cost.BillingMode)
|
||
}
|
||
|
||
func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) {
|
||
bs := newTestBillingService()
|
||
resolver := NewModelPricingResolver(nil, bs)
|
||
|
||
// Pre-resolve with per_request mode to verify it's used instead of re-resolving
|
||
preResolved := &ResolvedPricing{
|
||
Mode: BillingModePerRequest,
|
||
DefaultPerRequestPrice: 0.07,
|
||
}
|
||
|
||
cost, err := bs.CalculateCostUnified(CostInput{
|
||
Ctx: context.Background(),
|
||
Model: "claude-sonnet-4",
|
||
Tokens: UsageTokens{InputTokens: 100},
|
||
RequestCount: 2,
|
||
RateMultiplier: 1.0,
|
||
Resolver: resolver,
|
||
Resolved: preResolved,
|
||
})
|
||
require.NoError(t, err)
|
||
require.NotNil(t, cost)
|
||
|
||
// 2 * $0.07 = $0.14
|
||
require.InDelta(t, 0.14, cost.TotalCost, 1e-10)
|
||
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// helpers
|
||
// ---------------------------------------------------------------------------
|
||
|
||
// newTestChannelServiceWithCache creates a ChannelService with a pre-populated
|
||
// cache snapshot, bypassing the repository layer entirely.
|
||
func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService {
|
||
t.Helper()
|
||
cs := &ChannelService{}
|
||
cache.loadedAt = time.Now()
|
||
cs.cache.Store(cache)
|
||
return cs
|
||
}
|