Files
sub2api/backend/internal/handler/available_channel_handler.go
erio 800802b8aa feat(channels): explode available channels by platform + apply platform theme
Backend: one source channel → N output rows, one per platform that
has user-visible groups. Each row carries a single platform, so the
frontend can color/icon an entire row without mixing sources.

- userAvailableChannel: add Platform field
- new explodeChannelByPlatform helper; drop now-redundant
  collectGroupPlatforms

Frontend: use the row platform to drive theming and stop repeating
"ANTHROPIC" / "OPENAI" labels on every model chip.

- api/channels.ts: UserAvailableChannel.platform
- AvailableChannelsTable: name cell — PlatformBadge next to channel
  name (replaces the two-line name/description block; description
  moves to the badge's title tooltip); groups cell — each chip uses
  platformBadgeLightClass + PlatformIcon; model list passes
  show-platform=false + platform-hint to child chips
- SupportedModelChip: chip bg/border driven by platformBadgeClass,
  leading PlatformIcon; platform-hint fallback when model.platform
  missing
2026-04-21 18:47:54 +08:00

262 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
"sort"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AvailableChannelHandler serves the user-facing "available channels" query.
//
// List delegates to ChannelService.ListAvailable and applies four filters
// before responding:
//  1. Row filter: keep only channels whose status is Active and that share at
//     least one group with the groups the current user can access.
//  2. Group filter: each channel's Groups keep only the user-accessible ones.
//  3. Platform split: a channel is exploded into one row per platform found
//     among its visible groups (groups without a platform are ignored), and
//     each row's SupportedModels keep only the models of that row's platform.
//     This prevents cross-platform leakage, e.g. a channel attached to both an
//     antigravity group and an anthropic group must not show anthropic models
//     to a user who can only access the antigravity group.
//  4. Field whitelist: only user-relevant fields are returned; management
//     fields such as BillingModelSource, RestrictModels, the channel's
//     internal ID and Status are omitted.
type AvailableChannelHandler struct {
	channelService *service.ChannelService // source of channel data (ListAvailable)
	apiKeyService  *service.APIKeyService  // resolves which groups a user may access
	settingService *service.SettingService // feature toggle; nil means disabled
}
// NewAvailableChannelHandler wires the user-facing available-channels handler
// to its service dependencies. settingService may be nil; the feature is then
// treated as disabled and List answers authenticated callers with an empty list.
func NewAvailableChannelHandler(
	channelService *service.ChannelService,
	apiKeyService *service.APIKeyService,
	settingService *service.SettingService,
) *AvailableChannelHandler {
	h := new(AvailableChannelHandler)
	h.channelService = channelService
	h.apiKeyService = apiKeyService
	h.settingService = settingService
	return h
}
// featureEnabled reports whether the available-channels feature toggle is on.
// The feature is opt-in: without a settingService it is always considered off.
func (h *AvailableChannelHandler) featureEnabled(c *gin.Context) bool {
	if svc := h.settingService; svc != nil {
		return svc.GetAvailableChannelsRuntime(c.Request.Context()).Enabled
	}
	return false
}
// userAvailableGroup is the whitelisted group summary exposed to users.
// Field order is significant: encoding/json emits keys in declaration order.
type userAvailableGroup struct {
	ID       int64  `json:"id"`       // group ID
	Name     string `json:"name"`     // group display name
	Platform string `json:"platform"` // platform the group belongs to
}
// userSupportedModelPricing is the whitelisted pricing payload shown to users
// for a single supported model.
//
// BillingMode is a plain string so the JSON stays decoupled from the service
// enum type. Every price is a *float64, so an unset price encodes as JSON null
// instead of 0. Intervals carries the per-tier breakdown (see
// userPricingIntervalDTO).
type userSupportedModelPricing struct {
	BillingMode      string                   `json:"billing_mode"`
	InputPrice       *float64                 `json:"input_price"`
	OutputPrice      *float64                 `json:"output_price"`
	CacheWritePrice  *float64                 `json:"cache_write_price"`
	CacheReadPrice   *float64                 `json:"cache_read_price"`
	ImageOutputPrice *float64                 `json:"image_output_price"`
	PerRequestPrice  *float64                 `json:"per_request_price"`
	Intervals        []userPricingIntervalDTO `json:"intervals"`
}
// userPricingIntervalDTO is the whitelisted view of one pricing interval.
// Internal bookkeeping such as the interval's ID and SortOrder is deliberately
// left out because the frontend does not render it.
//
// Pointer fields encode as JSON null when unset; TierLabel is omitted from
// the JSON entirely when it is empty.
type userPricingIntervalDTO struct {
	MinTokens       int      `json:"min_tokens"`
	MaxTokens       *int     `json:"max_tokens"`
	TierLabel       string   `json:"tier_label,omitempty"`
	InputPrice      *float64 `json:"input_price"`
	OutputPrice     *float64 `json:"output_price"`
	CacheWritePrice *float64 `json:"cache_write_price"`
	CacheReadPrice  *float64 `json:"cache_read_price"`
	PerRequestPrice *float64 `json:"per_request_price"`
}
// userSupportedModel is one supported-model entry as exposed to users.
type userSupportedModel struct {
	Name     string                     `json:"name"`     // model name
	Platform string                     `json:"platform"` // platform this model belongs to
	Pricing  *userSupportedModelPricing `json:"pricing"`  // user-facing pricing; null when absent
}
// userAvailableChannel is one whitelisted "available channel" row for users.
//
// A channel with user-visible groups on several platforms is split into one
// row per platform: Platform names that platform, and Groups/SupportedModels
// contain only entries for it. The frontend can therefore theme a whole row
// (color, icon) without mixing platforms inside it.
type userAvailableChannel struct {
	Name            string               `json:"name"`             // channel name
	Description     string               `json:"description"`      // channel description
	Platform        string               `json:"platform"`         // the single platform of this row
	Groups          []userAvailableGroup `json:"groups"`           // visible groups on Platform
	SupportedModels []userSupportedModel `json:"supported_models"` // models on Platform
}
// List returns the "available channels" visible to the current user.
// GET /api/v1/channels/available
//
// Responses:
//   - 401 when the request carries no authenticated subject;
//   - an empty list when the feature toggle is off;
//   - otherwise the user's Active channels, exploded into one row per platform
//     of the groups the user can access.
//
// Errors from the group or channel lookups are forwarded via response.ErrorFrom.
func (h *AvailableChannelHandler) List(c *gin.Context) {
	subject, ok := middleware.GetAuthSubjectFromContext(c)
	if !ok {
		response.Unauthorized(c, "User not authenticated")
		return
	}

	// The toggle is checked only after authentication so unauthenticated
	// callers keep getting 401 exactly as before the toggle existed. When the
	// feature is off, respond with an empty list rather than leaking channels.
	if !h.featureEnabled(c) {
		response.Success(c, []userAvailableChannel{})
		return
	}

	ctx := c.Request.Context()

	userGroups, err := h.apiKeyService.GetAvailableGroups(ctx, subject.UserID)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	allowed := make(map[int64]struct{}, len(userGroups))
	for idx := range userGroups {
		allowed[userGroups[idx].ID] = struct{}{}
	}

	channels, err := h.channelService.ListAvailable(ctx)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}

	rows := make([]userAvailableChannel, 0, len(channels))
	for _, ch := range channels {
		if ch.Status != service.StatusActive {
			continue
		}
		// Channels that share no group with the user are skipped entirely.
		if groups := filterUserVisibleGroups(ch.Groups, allowed); len(groups) > 0 {
			rows = append(rows, explodeChannelByPlatform(ch, groups)...)
		}
	}
	response.Success(c, rows)
}
// explodeChannelByPlatform splits one channel into one output row per distinct
// platform found among visibleGroups.
//
// Each row carries only the visible groups of its platform (in their original
// order) and only the channel's supported models of that platform. Groups with
// an empty Platform are ignored; if none has a platform the result is nil.
// Rows are ordered by platform name (sort.Strings, ascending) so the output is
// deterministic for the frontend and for regression tests.
func explodeChannelByPlatform(
	ch service.AvailableChannel,
	visibleGroups []userAvailableGroup,
) []userAvailableChannel {
	byPlatform := make(map[string][]userAvailableGroup, 4)
	var platforms []string
	for _, g := range visibleGroups {
		if g.Platform == "" {
			continue
		}
		// Record each platform the first time it is seen; the map entry is
		// created by the append below, so this check fires once per platform.
		if _, seen := byPlatform[g.Platform]; !seen {
			platforms = append(platforms, g.Platform)
		}
		byPlatform[g.Platform] = append(byPlatform[g.Platform], g)
	}
	if len(platforms) == 0 {
		return nil
	}
	sort.Strings(platforms)

	rows := make([]userAvailableChannel, len(platforms))
	for i, p := range platforms {
		rows[i] = userAvailableChannel{
			Name:            ch.Name,
			Description:     ch.Description,
			Platform:        p,
			Groups:          byPlatform[p],
			SupportedModels: toUserSupportedModels(ch.SupportedModels, map[string]struct{}{p: {}}),
		}
	}
	return rows
}
// filterUserVisibleGroups converts the channel's group refs to user DTOs,
// keeping only groups whose ID is present in allowed. The result preserves the
// input order and is always non-nil (empty when nothing matches).
func filterUserVisibleGroups(
	groups []service.AvailableGroupRef,
	allowed map[int64]struct{},
) []userAvailableGroup {
	visible := make([]userAvailableGroup, 0, len(groups))
	for i := range groups {
		if _, permitted := allowed[groups[i].ID]; !permitted {
			continue
		}
		visible = append(visible, userAvailableGroup{
			ID:       groups[i].ID,
			Name:     groups[i].Name,
			Platform: groups[i].Platform,
		})
	}
	return visible
}
// toUserSupportedModels converts service-layer supported models into the user
// DTO (field whitelist), preserving input order.
//
// Only models whose Platform is a key of allowedPlatforms are kept, which
// prevents cross-platform model leakage. A nil allowedPlatforms disables the
// filter (everything is kept, e.g. for tests); a non-nil empty map keeps
// nothing. The result is always non-nil.
func toUserSupportedModels(
	src []service.SupportedModel,
	allowedPlatforms map[string]struct{},
) []userSupportedModel {
	keep := func(platform string) bool {
		if allowedPlatforms == nil {
			return true
		}
		_, ok := allowedPlatforms[platform]
		return ok
	}

	out := make([]userSupportedModel, 0, len(src))
	for _, model := range src {
		if !keep(model.Platform) {
			continue
		}
		out = append(out, userSupportedModel{
			Name:     model.Name,
			Platform: model.Platform,
			Pricing:  toUserPricing(model.Pricing),
		})
	}
	return out
}
// toUserPricing converts service-layer pricing into the user DTO, returning
// nil for nil input.
//
// An empty BillingMode is reported as service.BillingModeToken. Intervals is
// always a non-nil slice (so it encodes as [] rather than null), holding the
// whitelisted fields of each service interval in the same order.
func toUserPricing(p *service.ChannelModelPricing) *userSupportedModelPricing {
	if p == nil {
		return nil
	}

	mode := string(service.BillingModeToken)
	if s := string(p.BillingMode); s != "" {
		mode = s
	}

	dto := &userSupportedModelPricing{
		BillingMode:      mode,
		InputPrice:       p.InputPrice,
		OutputPrice:      p.OutputPrice,
		CacheWritePrice:  p.CacheWritePrice,
		CacheReadPrice:   p.CacheReadPrice,
		ImageOutputPrice: p.ImageOutputPrice,
		PerRequestPrice:  p.PerRequestPrice,
		Intervals:        make([]userPricingIntervalDTO, len(p.Intervals)),
	}
	for i, iv := range p.Intervals {
		dto.Intervals[i] = userPricingIntervalDTO{
			MinTokens:       iv.MinTokens,
			MaxTokens:       iv.MaxTokens,
			TierLabel:       iv.TierLabel,
			InputPrice:      iv.InputPrice,
			OutputPrice:     iv.OutputPrice,
			CacheWritePrice: iv.CacheWritePrice,
			CacheReadPrice:  iv.CacheReadPrice,
			PerRequestPrice: iv.PerRequestPrice,
		}
	}
	return dto
}