refactor(channels): centralize BillingModelSource normalization and exhaustive enum maps
- service: add normalizeBillingModelSource helper, apply in Create/GetByID/Update/List/ListAvailable outputs - handler: drop channelToResponse fallback now that service owns the default; add passthrough test - frontend: replace ternary status/billing-source lookups with Record<Enum, ...> maps so new union members fail the build - chip/table: drop local type aliases, reuse UserSupportedModel/UserPricingInterval directly - tests: assert short-circuit on ListAll error, wrap-prefix preservation, and Name-based default lookup
This commit is contained in:
@@ -686,9 +686,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
|
||||
AccountStatsPricingRules: input.AccountStatsPricingRules,
|
||||
}
|
||||
if channel.BillingModelSource == "" {
|
||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
normalizeBillingModelSource(channel)
|
||||
|
||||
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||
return nil, err
|
||||
@@ -704,12 +702,31 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
||||
}
|
||||
|
||||
s.invalidateCache()
|
||||
return s.repo.GetByID(ctx, channel.ID)
|
||||
created, err := s.repo.GetByID(ctx, channel.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizeBillingModelSource(created)
|
||||
return created, nil
|
||||
}
|
||||
|
||||
// GetByID 获取渠道详情
|
||||
// GetByID 获取渠道详情。返回前统一把空 BillingModelSource 回填为 ChannelMapped,
|
||||
// 让所有 handler 无需重复处理历史空值。
|
||||
func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
ch, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizeBillingModelSource(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// normalizeBillingModelSource 若 BillingModelSource 为空则回填默认值 ChannelMapped。
|
||||
// 统一在 service 层完成,避免 handler 响应层重复兜底。
|
||||
func normalizeBillingModelSource(ch *Channel) {
|
||||
if ch != nil && ch.BillingModelSource == "" {
|
||||
ch.BillingModelSource = BillingModelSourceChannelMapped
|
||||
}
|
||||
}
|
||||
|
||||
// Update 更新渠道
|
||||
@@ -741,7 +758,12 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
||||
s.invalidateCache()
|
||||
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
|
||||
|
||||
return s.repo.GetByID(ctx, id)
|
||||
updated, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalizeBillingModelSource(updated)
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
|
||||
@@ -859,7 +881,14 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||
|
||||
// List 获取渠道列表
|
||||
func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
|
||||
return s.repo.List(ctx, params, status, search)
|
||||
channels, res, err := s.repo.List(ctx, params, status, search)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := range channels {
|
||||
normalizeBillingModelSource(&channels[i])
|
||||
}
|
||||
return channels, res, nil
|
||||
}
|
||||
|
||||
// modelEntry 表示一个模型模式条目(用于冲突检测)
|
||||
|
||||
Reference in New Issue
Block a user