feat(channel): 模型映射 + 分组搜索 + 卡片折叠 + 冲突校验

- 渠道模型映射:新增 model_mapping JSONB 字段,在账号映射之前执行
- 分组选择:添加搜索过滤 + 平台图标
- 定价卡片:支持折叠/展开,已有数据默认折叠
- 模型冲突校验:前后端均禁止同一渠道内重复模型
- 迁移 083: channels 表添加 model_mapping 列
This commit is contained in:
erio
2026-03-30 02:36:04 +08:00
parent dca0054e93
commit 29d58f2414
9 changed files with 405 additions and 171 deletions

View File

@@ -24,18 +24,20 @@ func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler {
// --- Request / Response types ---
type createChannelRequest struct {
Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]string `json:"model_mapping"`
}
type updateChannelRequest struct {
Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]string `json:"model_mapping"`
}
type channelModelPricingRequest struct {
@@ -62,14 +64,15 @@ type pricingIntervalRequest struct {
}
type channelResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status string `json:"status"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status string `json:"status"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]string `json:"model_mapping"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type channelModelPricingResponse struct {
@@ -106,13 +109,17 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
GroupIDs: ch.GroupIDs,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
if resp.GroupIDs == nil {
resp.GroupIDs = []int64{}
}
if resp.ModelMapping == nil {
resp.ModelMapping = map[string]string{}
}
resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
for _, p := range ch.ModelPricing {
@@ -246,6 +253,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
Description: req.Description,
GroupIDs: req.GroupIDs,
ModelPricing: pricingRequestToService(req.ModelPricing),
ModelMapping: req.ModelMapping,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -271,10 +279,11 @@ func (h *ChannelHandler) Update(c *gin.Context) {
}
input := &service.UpdateChannelInput{
Name: req.Name,
Description: req.Description,
Status: req.Status,
GroupIDs: req.GroupIDs,
Name: req.Name,
Description: req.Description,
Status: req.Status,
GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping,
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
@@ -36,10 +37,14 @@ func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) err
func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
err := tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status) VALUES ($1, $2, $3)
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
err = tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status, model_mapping) VALUES ($1, $2, $3, $4)
RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status,
channel.Name, channel.Description, channel.Status, modelMappingJSON,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
@@ -68,16 +73,18 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
var modelMappingJSON []byte
err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, created_at, updated_at
`SELECT id, name, description, status, model_mapping, created_at, updated_at
FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt)
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
if err != nil {
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
@@ -96,10 +103,14 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, updated_at = NOW()
WHERE id = $4`,
channel.Name, channel.Description, channel.Status, channel.ID,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, updated_at = NOW()
WHERE id = $5`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
@@ -176,7 +187,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.created_at, c.updated_at
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`,
whereClause, argIdx, argIdx+1,
)
@@ -192,9 +203,11 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var channelIDs []int64
for rows.Next() {
var ch service.Channel
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -235,7 +248,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, created_at, updated_at FROM channels ORDER BY id`,
`SELECT id, name, description, status, model_mapping, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
@@ -246,9 +259,11 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var channelIDs []int64
for rows.Next() {
var ch service.Channel
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -390,3 +405,27 @@ func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channe
}
return conflicting, nil
}
// marshalModelMapping 将 model mapping 序列化为 JSON 字节nil/空 map 返回 '{}'
func marshalModelMapping(m map[string]string) ([]byte, error) {
if len(m) == 0 {
return []byte("{}"), nil
}
data, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal model_mapping: %w", err)
}
return data, nil
}
// unmarshalModelMapping 将 JSON 字节反序列化为 model mapping
func unmarshalModelMapping(data []byte) map[string]string {
if len(data) == 0 {
return nil
}
var m map[string]string
if err := json.Unmarshal(data, &m); err != nil {
return nil
}
return m
}

View File

@@ -36,6 +36,8 @@ type Channel struct {
GroupIDs []int64
// 模型定价列表
ModelPricing []ChannelModelPricing
// 渠道级模型映射
ModelMapping map[string]string
}
// ChannelModelPricing 渠道模型定价条目
@@ -71,6 +73,33 @@ type PricingInterval struct {
UpdatedAt time.Time
}
// ResolveMappedModel 解析渠道级模型映射,返回映射后的模型名。
// 支持通配符(如 "claude-*" → "claude-sonnet-4")。
// 如果没有匹配的映射规则,返回原始模型名。
func (c *Channel) ResolveMappedModel(requestedModel string) string {
if len(c.ModelMapping) == 0 {
return requestedModel
}
lower := strings.ToLower(requestedModel)
// 精确匹配优先
for src, dst := range c.ModelMapping {
if strings.ToLower(src) == lower {
return dst
}
}
// 通配符匹配
for src, dst := range c.ModelMapping {
srcLower := strings.ToLower(src)
if strings.HasSuffix(srcLower, "*") {
prefix := strings.TrimSuffix(srcLower, "*")
if strings.HasPrefix(lower, prefix) {
return dst
}
}
}
return requestedModel
}
// IsActive 判断渠道是否启用
func (c *Channel) IsActive() bool {
return c.Status == StatusActive
@@ -168,5 +197,11 @@ func (c *Channel) Clone() *Channel {
cp.ModelPricing[i] = c.ModelPricing[i].Clone()
}
}
if c.ModelMapping != nil {
cp.ModelMapping = make(map[string]string, len(c.ModelMapping))
for k, v := range c.ModelMapping {
cp.ModelMapping[k] = v
}
}
return &cp
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"strings"
"sync/atomic"
"time"
@@ -213,6 +214,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
Status: StatusActive,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
}
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := s.repo.Create(ctx, channel); err != nil {
@@ -270,6 +276,14 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
channel.ModelPricing = *input.ModelPricing
}
if input.ModelMapping != nil {
channel.ModelMapping = input.ModelMapping
}
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
@@ -318,6 +332,21 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
return s.repo.List(ctx, params, status, search)
}
// validateNoDuplicateModels 检查定价列表中是否有重复模型
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
seen := make(map[string]bool)
for _, p := range pricingList {
for _, model := range p.Models {
lower := strings.ToLower(model)
if seen[lower] {
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries", model))
}
seen[lower] = true
}
}
return nil
}
// --- Input types ---
// CreateChannelInput 创建渠道输入
@@ -326,6 +355,7 @@ type CreateChannelInput struct {
Description string
GroupIDs []int64
ModelPricing []ChannelModelPricing
ModelMapping map[string]string
}
// UpdateChannelInput 更新渠道输入
@@ -335,4 +365,5 @@ type UpdateChannelInput struct {
Status string
GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing
ModelMapping map[string]string
}