feat(channel): 模型映射 + 分组搜索 + 卡片折叠 + 冲突校验

- 渠道模型映射:新增 model_mapping JSONB 字段,在账号映射之前执行
- 分组选择:添加搜索过滤 + 平台图标
- 定价卡片:支持折叠/展开,已有数据默认折叠
- 模型冲突校验:前后端均禁止同一渠道内重复模型
- 迁移 083: channels 表添加 model_mapping 列
This commit is contained in:
erio
2026-03-30 02:36:04 +08:00
parent dca0054e93
commit 29d58f2414
9 changed files with 405 additions and 171 deletions

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
@@ -36,10 +37,14 @@ func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) err
func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
err := tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status) VALUES ($1, $2, $3)
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
err = tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status, model_mapping) VALUES ($1, $2, $3, $4)
RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status,
channel.Name, channel.Description, channel.Status, modelMappingJSON,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
@@ -68,16 +73,18 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
var modelMappingJSON []byte
err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, created_at, updated_at
`SELECT id, name, description, status, model_mapping, created_at, updated_at
FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt)
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
if err != nil {
return nil, fmt.Errorf("get channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
@@ -96,10 +103,14 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
if err != nil {
return err
}
result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, updated_at = NOW()
WHERE id = $4`,
channel.Name, channel.Description, channel.Status, channel.ID,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, updated_at = NOW()
WHERE id = $5`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
@@ -176,7 +187,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表
dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.created_at, c.updated_at
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`,
whereClause, argIdx, argIdx+1,
)
@@ -192,9 +203,11 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var channelIDs []int64
for rows.Next() {
var ch service.Channel
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -235,7 +248,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, created_at, updated_at FROM channels ORDER BY id`,
`SELECT id, name, description, status, model_mapping, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
@@ -246,9 +259,11 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var channelIDs []int64
for rows.Next() {
var ch service.Channel
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
var modelMappingJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
@@ -390,3 +405,27 @@ func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channe
}
return conflicting, nil
}
// marshalModelMapping 将 model mapping 序列化为 JSON 字节nil/空 map 返回 '{}'
func marshalModelMapping(m map[string]string) ([]byte, error) {
if len(m) == 0 {
return []byte("{}"), nil
}
data, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal model_mapping: %w", err)
}
return data, nil
}
// unmarshalModelMapping 将 JSON 字节反序列化为 model mapping
func unmarshalModelMapping(data []byte) map[string]string {
if len(data) == 0 {
return nil
}
var m map[string]string
if err := json.Unmarshal(data, &m); err != nil {
return nil
}
return m
}