diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 9151d018..950e6e72 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -158,9 +158,6 @@ func channelToResponse(ch *service.Channel) *channelResponse { UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), } resp.BillingModelSource = ch.BillingModelSource - if resp.BillingModelSource == "" { - resp.BillingModelSource = service.BillingModelSourceChannelMapped - } if resp.GroupIDs == nil { resp.GroupIDs = []int64{} } diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go index f218cce4..12cd4bdd 100644 --- a/backend/internal/handler/admin/channel_handler_test.go +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -91,7 +91,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) { ch := &service.Channel{ ID: 1, Name: "ch", - BillingModelSource: "", + BillingModelSource: service.BillingModelSourceChannelMapped, CreatedAt: now, UpdatedAt: now, GroupIDs: nil, @@ -105,6 +105,9 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) { }, } + // handler 层 channelToResponse 现在是纯透传:BillingModelSource 的空值兜底 + // 已下放到 service 层(Create/GetByID/List/Update/ListAvailable 出口统一处理), + // 因此这里构造 fixture 时直接传入归一化后的值。 resp := channelToResponse(ch) require.Equal(t, "channel_mapped", resp.BillingModelSource) require.NotNil(t, resp.GroupIDs) @@ -117,6 +120,19 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) { require.Equal(t, "token", resp.ModelPricing[0].BillingMode) } +func TestChannelToResponse_BillingModelSourcePassthrough(t *testing.T) { + // handler 不再兜底 BillingModelSource:空值应原样透传(由 service 层负责默认回填)。 + ch := &service.Channel{ + ID: 1, + Name: "ch", + BillingModelSource: "", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + resp := channelToResponse(ch) + require.Equal(t, "", resp.BillingModelSource, "handler 应纯透传,默认值由 service.normalizeBillingModelSource 负责") +} + func TestChannelToResponse_NilModels(t *testing.T) { now := time.Now() ch := &service.Channel{ diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go index 62406cd0..a162d81d 100644 --- a/backend/internal/service/channel_available.go +++ b/backend/internal/service/channel_available.go @@ -32,6 +32,9 @@ type AvailableChannel struct { // 支持模型通过 (*Channel).SupportedModels() 计算得到(见 channel.go)。 // 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中 // 的分组(已停用或删除)会被忽略。 +// +// 前置条件:s.groupRepo 必须非 nil(由 wire DI 保证)。直接 nil-deref 用于 fail-fast, +// 避免静默掩盖注入缺失。 func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) { channels, err := s.repo.ListAll(ctx) if err != nil { @@ -61,19 +64,16 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, groups = append(groups, ref) } } - sort.Slice(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name }) + sort.SliceStable(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name }) - billingSource := ch.BillingModelSource - if billingSource == "" { - billingSource = BillingModelSourceChannelMapped - } + normalizeBillingModelSource(ch) out = append(out, AvailableChannel{ ID: ch.ID, Name: ch.Name, Description: ch.Description, Status: ch.Status, - BillingModelSource: billingSource, + BillingModelSource: ch.BillingModelSource, RestrictModels: ch.RestrictModels, Groups: groups, SupportedModels: ch.SupportedModels(), diff --git a/backend/internal/service/channel_available_test.go b/backend/internal/service/channel_available_test.go index 5da5e6e1..86bb4bb6 100644 --- a/backend/internal/service/channel_available_test.go +++ b/backend/internal/service/channel_available_test.go @@ -14,12 +14,15 @@ import ( // stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub, // 仅实现 ListActive;其他方法对本测试无关,返回零值即可。 // listActiveErr 非 nil 时,ListActive 返回该错误用于错误传播测试。 +// listActiveCalls 记录调用次数,用于断言「失败短路时不再访问 groupRepo」等行为。 type stubGroupRepoForAvailable struct { - activeGroups []Group - listActiveErr error + activeGroups []Group + listActiveErr error + listActiveCalls int } func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) { + s.listActiveCalls++ if s.listActiveErr != nil { return nil, s.listActiveErr } @@ -125,15 +128,18 @@ func TestListAvailable_SortedByName(t *testing.T) { } func TestListAvailable_ListAllErrorPropagates(t *testing.T) { - // ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,不再访问 groupRepo。 + // ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,且不再访问 groupRepo(短路)。 sentinel := errors.New("list-all-boom") repo := &mockChannelRepository{ listAllFn: func(ctx context.Context) ([]Channel, error) { return nil, sentinel }, } - svc := NewChannelService(repo, &stubGroupRepoForAvailable{}, nil) + groupRepo := &stubGroupRepoForAvailable{} + svc := NewChannelService(repo, groupRepo, nil) out, err := svc.ListAvailable(context.Background()) require.Nil(t, out) require.ErrorIs(t, err, sentinel) + require.Contains(t, err.Error(), "list channels", "wrap 前缀缺失,可能 %w 被改为 %v") + require.Equal(t, 0, groupRepo.listActiveCalls, "ListAll 失败后不应再调用 groupRepo.ListActive") } func TestListAvailable_ListActiveErrorPropagates(t *testing.T) { @@ -146,6 +152,7 @@ func TestListAvailable_ListActiveErrorPropagates(t *testing.T) { out, err := svc.ListAvailable(context.Background()) require.Nil(t, out) require.ErrorIs(t, err, sentinel) + require.Contains(t, err.Error(), "list active groups", "wrap 前缀缺失,可能 %w 被改为 %v") } func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) { @@ -159,6 +166,12 @@ func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) { out, err := svc.ListAvailable(context.Background()) require.NoError(t, err) require.Len(t, out, 2) - require.Equal(t, BillingModelSourceChannelMapped, out[0].BillingModelSource) - require.Equal(t, BillingModelSourceUpstream, out[1].BillingModelSource) + + // 按 Name 查找,避免依赖排序副作用。 + byName := make(map[string]string, len(out)) + for _, ch := range out { + byName[ch.Name] = ch.BillingModelSource + } + require.Equal(t, BillingModelSourceChannelMapped, byName["empty"]) + require.Equal(t, BillingModelSourceUpstream, byName["explicit"]) } diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 250df07b..4f22e205 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -686,9 +686,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) ApplyPricingToAccountStats: input.ApplyPricingToAccountStats, AccountStatsPricingRules: input.AccountStatsPricingRules, } - if channel.BillingModelSource == "" { - channel.BillingModelSource = BillingModelSourceChannelMapped - } + normalizeBillingModelSource(channel) if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { return nil, err @@ -704,12 +702,31 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) } s.invalidateCache() - return s.repo.GetByID(ctx, channel.ID) + created, err := s.repo.GetByID(ctx, channel.ID) + if err != nil { + return nil, err + } + normalizeBillingModelSource(created) + return created, nil } -// GetByID 获取渠道详情 +// GetByID 获取渠道详情。返回前统一把空 BillingModelSource 回填为 ChannelMapped, +// 让所有 handler 无需重复处理历史空值。 func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) { - return s.repo.GetByID(ctx, id) + ch, err := s.repo.GetByID(ctx, id) + if err != nil { + return nil, err + } + normalizeBillingModelSource(ch) + return ch, nil +} + +// normalizeBillingModelSource 若 BillingModelSource 为空则回填默认值 ChannelMapped。 +// 统一在 service 层完成,避免 handler 响应层重复兜底。 +func normalizeBillingModelSource(ch *Channel) { + if ch != nil && ch.BillingModelSource == "" { + ch.BillingModelSource = BillingModelSourceChannelMapped + } } // Update 更新渠道 @@ -741,7 +758,12 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan s.invalidateCache() s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs) - return s.repo.GetByID(ctx, id) + updated, err := s.repo.GetByID(ctx, id) + if err != nil { + return nil, err + } + normalizeBillingModelSource(updated) + return updated, nil } // applyUpdateInput 将更新请求的字段应用到渠道实体上。 @@ -859,7 +881,14 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error { // List 获取渠道列表 func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) { - return s.repo.List(ctx, params, status, search) + channels, res, err := s.repo.List(ctx, params, status, search) + if err != nil { + return nil, nil, err + } + for i := range channels { + normalizeBillingModelSource(&channels[i]) + } + return channels, res, nil } // modelEntry 表示一个模型模式条目(用于冲突检测) diff --git a/frontend/src/components/channels/AvailableChannelsTable.vue b/frontend/src/components/channels/AvailableChannelsTable.vue index e9011ec5..0bd19518 100644 --- a/frontend/src/components/channels/AvailableChannelsTable.vue +++ b/frontend/src/components/channels/AvailableChannelsTable.vue @@ -61,7 +61,7 @@ import { computed, useSlots } from 'vue' import DataTable from '@/components/common/DataTable.vue' import Icon from '@/components/icons/Icon.vue' import SupportedModelChip from './SupportedModelChip.vue' -import type { UserSupportedModelPricing } from '@/api/channels' +import type { UserSupportedModel } from '@/api/channels' interface GroupRef { id: number @@ -73,11 +73,8 @@ interface Row { name: string description?: string groups: GroupRef[] - supported_models: Array<{ - name: string - platform: string - pricing: UserSupportedModelPricing | null - }> + // 复用 user 侧最小 DTO;admin 侧 SupportedModel 结构上是其超集,可直接传入。 + supported_models: UserSupportedModel[] [key: string]: unknown } diff --git a/frontend/src/components/channels/SupportedModelChip.vue b/frontend/src/components/channels/SupportedModelChip.vue index f3e5549b..600e3ef5 100644 --- a/frontend/src/components/channels/SupportedModelChip.vue +++ b/frontend/src/components/channels/SupportedModelChip.vue @@ -127,19 +127,13 @@ import { BILLING_MODE_IMAGE, type BillingMode } from '@/constants/channel' +// 复用 api/channels.ts 的用户侧最小形态 DTO。 +// admin 侧 ChannelModelPricing 字段更多,但结构上是用户 DTO 的超集,admin 视图传入可直接通过结构化子类型检查。 import type { UserPricingInterval, UserSupportedModel } from '@/api/channels' -/** - * 复用 api/channels.ts 的用户侧最小形态 DTO。 - * admin 侧 ChannelModelPricing 字段更多,但结构上是用户 DTO 的超集, - * 因此 admin 视图传入时 TypeScript 结构化子类型会直接通过。 - */ -type PricingInterval = UserPricingInterval -type SupportedModelLike = UserSupportedModel - const props = withDefaults( defineProps<{ - model: SupportedModelLike + model: UserSupportedModel /** i18n 前缀:管理端传 `admin.availableChannels.pricing`,用户端传 `availableChannels.pricing`。 */ pricingKeyPrefix?: string noPricingLabel?: string @@ -180,7 +174,7 @@ function formatRange(min: number, max: number | null): string { return `(${min}, ${maxLabel}]` } -function formatInterval(iv: PricingInterval, mode: BillingMode): string { +function formatInterval(iv: UserPricingInterval, mode: BillingMode): string { if (mode === BILLING_MODE_PER_REQUEST || mode === BILLING_MODE_IMAGE) { return formatScaled(iv.per_request_price, 1) } diff --git a/frontend/src/views/admin/AvailableChannelsView.vue b/frontend/src/views/admin/AvailableChannelsView.vue index c7c27154..74e85618 100644 --- a/frontend/src/views/admin/AvailableChannelsView.vue +++ b/frontend/src/views/admin/AvailableChannelsView.vue @@ -46,20 +46,16 @@ @@ -78,7 +74,15 @@ import AvailableChannelsTable from '@/components/channels/AvailableChannelsTable import channelsAPI, { type AvailableChannel } from '@/api/admin/channels' import { useAppStore } from '@/stores/app' import { extractApiErrorMessage } from '@/utils/apiError' -import { CHANNEL_STATUS_ACTIVE, type ChannelStatus } from '@/constants/channel' +import { + CHANNEL_STATUS_ACTIVE, + CHANNEL_STATUS_DISABLED, + BILLING_MODEL_SOURCE_REQUESTED, + BILLING_MODEL_SOURCE_UPSTREAM, + BILLING_MODEL_SOURCE_CHANNEL_MAPPED, + type ChannelStatus, + type BillingModelSource +} from '@/constants/channel' const { t } = useI18n() const appStore = useAppStore() @@ -95,11 +99,30 @@ const columns = computed(() => [ { key: 'supported_models', label: t('admin.availableChannels.columns.supportedModels') } ]) -function statusLabel(status: ChannelStatus): string { - return status === CHANNEL_STATUS_ACTIVE - ? t('admin.availableChannels.statusActive') - : t('admin.availableChannels.statusDisabled') -} +/** + * 显示样式:i18n label + Tailwind class,按 ChannelStatus 完整穷举。 + * 用 Record 强制未来新增状态时 TS 编译失败,避免遗漏分支。 + */ +const statusStyles = computed>(() => ({ + [CHANNEL_STATUS_ACTIVE]: { + label: t('admin.availableChannels.statusActive'), + cls: 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-400' + }, + [CHANNEL_STATUS_DISABLED]: { + label: t('admin.availableChannels.statusDisabled'), + cls: 'bg-gray-100 text-gray-600 dark:bg-dark-700 dark:text-gray-400' + } +})) + +/** + * BillingModelSource 显式映射:避免将后端 snake_case 字面量直接拼成 i18n key, + * 同时在 BillingModelSource 扩展时 TS 编译失败以暴露遗漏。 + */ +const billingSourceLabels = computed>(() => ({ + [BILLING_MODEL_SOURCE_REQUESTED]: t('admin.availableChannels.billingSource.requested'), + [BILLING_MODEL_SOURCE_UPSTREAM]: t('admin.availableChannels.billingSource.upstream'), + [BILLING_MODEL_SOURCE_CHANNEL_MAPPED]: t('admin.availableChannels.billingSource.channel_mapped') +})) const filteredChannels = computed(() => { const q = searchQuery.value.trim().toLowerCase()