refactor(channels): centralize BillingModelSource normalization and exhaustive enum maps
- service: add normalizeBillingModelSource helper, apply in Create/GetByID/Update/List/ListAvailable outputs - handler: drop channelToResponse fallback now that service owns the default; add passthrough test - frontend: replace ternary status/billing-source lookups with Record<Enum, ...> maps so new union members fail the build - chip/table: drop local type aliases, reuse UserSupportedModel/UserPricingInterval directly - tests: assert short-circuit on ListAll error, wrap-prefix preservation, and Name-based default lookup
This commit is contained in:
@@ -32,6 +32,9 @@ type AvailableChannel struct {
|
||||
// 支持模型通过 (*Channel).SupportedModels() 计算得到(见 channel.go)。
|
||||
// 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中
|
||||
// 的分组(已停用或删除)会被忽略。
|
||||
//
|
||||
// 前置条件:s.groupRepo 必须非 nil(由 wire DI 保证)。直接 nil-deref 用于 fail-fast,
|
||||
// 避免静默掩盖注入缺失。
|
||||
func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) {
|
||||
channels, err := s.repo.ListAll(ctx)
|
||||
if err != nil {
|
||||
@@ -61,19 +64,16 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel,
|
||||
groups = append(groups, ref)
|
||||
}
|
||||
}
|
||||
sort.Slice(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
|
||||
sort.SliceStable(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
|
||||
|
||||
billingSource := ch.BillingModelSource
|
||||
if billingSource == "" {
|
||||
billingSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
normalizeBillingModelSource(ch)
|
||||
|
||||
out = append(out, AvailableChannel{
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Description: ch.Description,
|
||||
Status: ch.Status,
|
||||
BillingModelSource: billingSource,
|
||||
BillingModelSource: ch.BillingModelSource,
|
||||
RestrictModels: ch.RestrictModels,
|
||||
Groups: groups,
|
||||
SupportedModels: ch.SupportedModels(),
|
||||
|
||||
@@ -14,12 +14,15 @@ import (
|
||||
// stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub,
|
||||
// 仅实现 ListActive;其他方法对本测试无关,返回零值即可。
|
||||
// listActiveErr 非 nil 时,ListActive 返回该错误用于错误传播测试。
|
||||
// listActiveCalls 记录调用次数,用于断言「失败短路时不再访问 groupRepo」等行为。
|
||||
type stubGroupRepoForAvailable struct {
|
||||
activeGroups []Group
|
||||
listActiveErr error
|
||||
activeGroups []Group
|
||||
listActiveErr error
|
||||
listActiveCalls int
|
||||
}
|
||||
|
||||
func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) {
|
||||
s.listActiveCalls++
|
||||
if s.listActiveErr != nil {
|
||||
return nil, s.listActiveErr
|
||||
}
|
||||
@@ -125,15 +128,18 @@ func TestListAvailable_SortedByName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestListAvailable_ListAllErrorPropagates(t *testing.T) {
|
||||
// ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,不再访问 groupRepo。
|
||||
// ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,且不再访问 groupRepo(短路)。
|
||||
sentinel := errors.New("list-all-boom")
|
||||
repo := &mockChannelRepository{
|
||||
listAllFn: func(ctx context.Context) ([]Channel, error) { return nil, sentinel },
|
||||
}
|
||||
svc := NewChannelService(repo, &stubGroupRepoForAvailable{}, nil)
|
||||
groupRepo := &stubGroupRepoForAvailable{}
|
||||
svc := NewChannelService(repo, groupRepo, nil)
|
||||
out, err := svc.ListAvailable(context.Background())
|
||||
require.Nil(t, out)
|
||||
require.ErrorIs(t, err, sentinel)
|
||||
require.Contains(t, err.Error(), "list channels", "wrap 前缀缺失,可能 %w 被改为 %v")
|
||||
require.Equal(t, 0, groupRepo.listActiveCalls, "ListAll 失败后不应再调用 groupRepo.ListActive")
|
||||
}
|
||||
|
||||
func TestListAvailable_ListActiveErrorPropagates(t *testing.T) {
|
||||
@@ -146,6 +152,7 @@ func TestListAvailable_ListActiveErrorPropagates(t *testing.T) {
|
||||
out, err := svc.ListAvailable(context.Background())
|
||||
require.Nil(t, out)
|
||||
require.ErrorIs(t, err, sentinel)
|
||||
require.Contains(t, err.Error(), "list active groups", "wrap 前缀缺失,可能 %w 被改为 %v")
|
||||
}
|
||||
|
||||
func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) {
|
||||
@@ -159,6 +166,12 @@ func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) {
|
||||
out, err := svc.ListAvailable(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out, 2)
|
||||
require.Equal(t, BillingModelSourceChannelMapped, out[0].BillingModelSource)
|
||||
require.Equal(t, BillingModelSourceUpstream, out[1].BillingModelSource)
|
||||
|
||||
// 按 Name 查找,避免依赖排序副作用。
|
||||
byName := make(map[string]string, len(out))
|
||||
for _, ch := range out {
|
||||
byName[ch.Name] = ch.BillingModelSource
|
||||
}
|
||||
require.Equal(t, BillingModelSourceChannelMapped, byName["empty"])
|
||||
require.Equal(t, BillingModelSourceUpstream, byName["explicit"])
|
||||
}
|
||||
|
||||
@@ -686,9 +686,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
|
||||
AccountStatsPricingRules: input.AccountStatsPricingRules,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
normalizeBillingModelSource(channel)
|
||||
|
||||
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
@@ -704,12 +702,31 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
return s.repo.GetByID(ctx, channel.ID)
|
||||
created, err := s.repo.GetByID(ctx, channel.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizeBillingModelSource(created)
|
||||
return created, nil
|
||||
}
|
||||
|
||||
// GetByID 获取渠道详情
|
||||
// GetByID 获取渠道详情。返回前统一把空 BillingModelSource 回填为 ChannelMapped,
|
||||
// 让所有 handler 无需重复处理历史空值。
|
||||
func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
ch, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizeBillingModelSource(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// normalizeBillingModelSource 若 BillingModelSource 为空则回填默认值 ChannelMapped。
|
||||
// 统一在 service 层完成,避免 handler 响应层重复兜底。
|
||||
func normalizeBillingModelSource(ch *Channel) {
|
||||
if ch != nil && ch.BillingModelSource == "" {
|
||||
ch.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
}
|
||||
|
||||
// Update 更新渠道
|
||||
@@ -741,7 +758,12 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
s.invalidateCache()
|
||||
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
|
||||
|
||||
return s.repo.GetByID(ctx, id)
|
||||
updated, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizeBillingModelSource(updated)
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
|
||||
@@ -859,7 +881,14 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
|
||||
// List 获取渠道列表
|
||||
func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
|
||||
return s.repo.List(ctx, params, status, search)
|
||||
channels, res, err := s.repo.List(ctx, params, status, search)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := range channels {
|
||||
normalizeBillingModelSource(&channels[i])
|
||||
}
|
||||
return channels, res, nil
|
||||
}
|
||||
|
||||
// modelEntry 表示一个模型模式条目(用于冲突检测)
|
||||
|
||||
Reference in New Issue
Block a user