diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 492b6b8f..f47d6791 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -24,18 +24,20 @@ func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler { // --- Request / Response types --- type createChannelRequest struct { - Name string `json:"name" binding:"required,max=100"` - Description string `json:"description"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingRequest `json:"model_pricing"` + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]string `json:"model_mapping"` } type updateChannelRequest struct { - Name string `json:"name" binding:"omitempty,max=100"` - Description *string `json:"description"` - Status string `json:"status" binding:"omitempty,oneof=active disabled"` - GroupIDs *[]int64 `json:"group_ids"` - ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]string `json:"model_mapping"` } type channelModelPricingRequest struct { @@ -62,14 +64,15 @@ type pricingIntervalRequest struct { } type channelResponse struct { - ID int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Status string `json:"status"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingResponse `json:"model_pricing"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + ModelMapping map[string]string `json:"model_mapping"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } type channelModelPricingResponse struct { @@ -106,13 +109,17 @@ func channelToResponse(ch *service.Channel) *channelResponse { Name: ch.Name, Description: ch.Description, Status: ch.Status, - GroupIDs: ch.GroupIDs, + GroupIDs: ch.GroupIDs, + ModelMapping: ch.ModelMapping, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), } if resp.GroupIDs == nil { resp.GroupIDs = []int64{} } + if resp.ModelMapping == nil { + resp.ModelMapping = map[string]string{} + } resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing)) for _, p := range ch.ModelPricing { @@ -246,6 +253,7 @@ func (h *ChannelHandler) Create(c *gin.Context) { Description: req.Description, GroupIDs: req.GroupIDs, ModelPricing: pricingRequestToService(req.ModelPricing), + ModelMapping: req.ModelMapping, }) if err != nil { response.ErrorFrom(c, err) @@ -271,10 +279,11 @@ func (h *ChannelHandler) Update(c *gin.Context) { } input := &service.UpdateChannelInput{ - Name: req.Name, - Description: req.Description, - Status: req.Status, - GroupIDs: req.GroupIDs, + Name: req.Name, + Description: req.Description, + Status: req.Status, + GroupIDs: req.GroupIDs, + ModelMapping: req.ModelMapping, } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index 9259edd6..eaf25668 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "encoding/json" "fmt" "strings" @@ -36,10 +37,14 @@ func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) err func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error { return r.runInTx(ctx, func(tx *sql.Tx) error { - err := tx.QueryRowContext(ctx, - `INSERT INTO channels (name, description, status) VALUES ($1, $2, $3) + modelMappingJSON, err := marshalModelMapping(channel.ModelMapping) + if err != nil { + return err + } + err = tx.QueryRowContext(ctx, + `INSERT INTO channels (name, description, status, model_mapping) VALUES ($1, $2, $3, $4) RETURNING id, created_at, updated_at`, - channel.Name, channel.Description, channel.Status, + channel.Name, channel.Description, channel.Status, modelMappingJSON, ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) if err != nil { if isUniqueViolation(err) { @@ -68,16 +73,18 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { ch := &service.Channel{} + var modelMappingJSON []byte err := r.db.QueryRowContext(ctx, - `SELECT id, name, description, status, created_at, updated_at + `SELECT id, name, description, status, model_mapping, created_at, updated_at FROM channels WHERE id = $1`, id, - ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt) + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt) if err == sql.ErrNoRows { return nil, service.ErrChannelNotFound } if err != nil { return nil, fmt.Errorf("get channel: %w", err) } + ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) groupIDs, err := r.GetGroupIDs(ctx, id) if err != nil { @@ -96,10 +103,14 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error { return r.runInTx(ctx, func(tx *sql.Tx) error { + modelMappingJSON, err := marshalModelMapping(channel.ModelMapping) + if err != nil { + return err + } result, err := tx.ExecContext(ctx, - `UPDATE channels SET name = $1, description = $2, status = $3, updated_at = NOW() - WHERE id = $4`, - channel.Name, channel.Description, channel.Status, channel.ID, + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, updated_at = NOW() + WHERE id = $5`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.ID, ) if err != nil { if isUniqueViolation(err) { @@ -176,7 +187,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati // 查询 channel 列表 dataQuery := fmt.Sprintf( - `SELECT c.id, c.name, c.description, c.status, c.created_at, c.updated_at + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.created_at, c.updated_at FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`, whereClause, argIdx, argIdx+1, ) @@ -192,9 +203,11 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati var channelIDs []int64 for rows.Next() { var ch service.Channel - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, nil, fmt.Errorf("scan channel: %w", err) } + ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -235,7 +248,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, name, description, status, created_at, updated_at FROM channels ORDER BY id`, + `SELECT id, name, description, status, model_mapping, created_at, updated_at FROM channels ORDER BY id`, ) if err != nil { return nil, fmt.Errorf("query all channels: %w", err) @@ -246,9 +259,11 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err var channelIDs []int64 for rows.Next() { var ch service.Channel - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, fmt.Errorf("scan channel: %w", err) } + ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -390,3 +405,27 @@ func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channe } return conflicting, nil } + +// marshalModelMapping 将 model mapping 序列化为 JSON 字节,nil/空 map 返回 '{}' +func marshalModelMapping(m map[string]string) ([]byte, error) { + if len(m) == 0 { + return []byte("{}"), nil + } + data, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal model_mapping: %w", err) + } + return data, nil +} + +// unmarshalModelMapping 将 JSON 字节反序列化为 model mapping +func unmarshalModelMapping(data []byte) map[string]string { + if len(data) == 0 { + return nil + } + var m map[string]string + if err := json.Unmarshal(data, &m); err != nil { + return nil + } + return m +} diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index be82b997..7b43b18b 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -36,6 +36,8 @@ type Channel struct { GroupIDs []int64 // 模型定价列表 ModelPricing []ChannelModelPricing + // 渠道级模型映射 + ModelMapping map[string]string } // ChannelModelPricing 渠道模型定价条目 @@ -71,6 +73,33 @@ type PricingInterval struct { UpdatedAt time.Time } +// ResolveMappedModel 解析渠道级模型映射,返回映射后的模型名。 +// 支持通配符(如 "claude-*" → "claude-sonnet-4")。 +// 如果没有匹配的映射规则,返回原始模型名。 +func (c *Channel) ResolveMappedModel(requestedModel string) string { + if len(c.ModelMapping) == 0 { + return requestedModel + } + lower := strings.ToLower(requestedModel) + // 精确匹配优先 + for src, dst := range c.ModelMapping { + if strings.ToLower(src) == lower { + return dst + } + } + // 通配符匹配 + for src, dst := range c.ModelMapping { + srcLower := strings.ToLower(src) + if strings.HasSuffix(srcLower, "*") { + prefix := strings.TrimSuffix(srcLower, "*") + if strings.HasPrefix(lower, prefix) { + return dst + } + } + } + return requestedModel +} + // IsActive 判断渠道是否启用 func (c *Channel) IsActive() bool { return c.Status == StatusActive @@ -168,5 +197,11 @@ func (c *Channel) Clone() *Channel { cp.ModelPricing[i] = c.ModelPricing[i].Clone() } } + if c.ModelMapping != nil { + cp.ModelMapping = make(map[string]string, len(c.ModelMapping)) + for k, v := range c.ModelMapping { + cp.ModelMapping[k] = v + } + } return &cp } diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 8f00481f..adf1a64f 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "strings" "sync/atomic" "time" @@ -213,6 +214,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) Status: StatusActive, GroupIDs: input.GroupIDs, ModelPricing: input.ModelPricing, + ModelMapping: input.ModelMapping, + } + + if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { + return nil, err } if err := s.repo.Create(ctx, channel); err != nil { @@ -270,6 +276,14 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan channel.ModelPricing = *input.ModelPricing } + if input.ModelMapping != nil { + channel.ModelMapping = input.ModelMapping + } + + if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := s.repo.Update(ctx, channel); err != nil { return nil, fmt.Errorf("update channel: %w", err) } @@ -318,6 +332,21 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP return s.repo.List(ctx, params, status, search) } +// validateNoDuplicateModels 检查定价列表中是否有重复模型 +func validateNoDuplicateModels(pricingList []ChannelModelPricing) error { + seen := make(map[string]bool) + for _, p := range pricingList { + for _, model := range p.Models { + lower := strings.ToLower(model) + if seen[lower] { + return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries", model)) + } + seen[lower] = true + } + } + return nil +} + // --- Input types --- // CreateChannelInput 创建渠道输入 @@ -326,6 +355,7 @@ type CreateChannelInput struct { Description string GroupIDs []int64 ModelPricing []ChannelModelPricing + ModelMapping map[string]string } // UpdateChannelInput 更新渠道输入 @@ -335,4 +365,5 @@ type UpdateChannelInput struct { Status string GroupIDs *[]int64 ModelPricing *[]ChannelModelPricing + ModelMapping map[string]string } diff --git a/backend/migrations/083_channel_model_mapping.sql b/backend/migrations/083_channel_model_mapping.sql new file mode 100644 index 00000000..68e2203f --- /dev/null +++ b/backend/migrations/083_channel_model_mapping.sql @@ -0,0 +1,5 @@ +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +ALTER TABLE channels ADD COLUMN IF NOT EXISTS model_mapping JSONB DEFAULT '{}'; +COMMENT ON COLUMN channels.model_mapping IS '渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}'; diff --git a/frontend/src/components/admin/channel/PricingEntryCard.vue b/frontend/src/components/admin/channel/PricingEntryCard.vue index 84eca5ee..b0238a19 100644 --- a/frontend/src/components/admin/channel/PricingEntryCard.vue +++ b/frontend/src/components/admin/channel/PricingEntryCard.vue @@ -1,148 +1,209 @@ + + diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index f738b277..679f8290 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1743,6 +1743,7 @@ export default { updateError: 'Failed to update channel', deleteError: 'Failed to delete channel', nameRequired: 'Please enter a channel name', + duplicateModels: 'Model "{0}" appears in multiple pricing entries', deleteConfirm: 'Are you sure you want to delete channel "{name}"? This cannot be undone.', columns: { name: 'Name', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 34469e9c..dfb859a5 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1823,6 +1823,7 @@ export default { updateError: '更新渠道失败', deleteError: '删除渠道失败', nameRequired: '请输入渠道名称', + duplicateModels: '模型「{0}」在多个定价条目中重复', deleteConfirm: '确定要删除渠道「{name}」吗?此操作不可撤销。', columns: { name: '名称', diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index df4cac91..9af183fa 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -176,6 +176,19 @@
+
+ + +
@@ -185,8 +198,11 @@
{{ t('admin.channels.form.noGroupsAvailable', 'No groups available') }}
+
+ {{ t('admin.channels.form.noGroupsMatch', 'No groups match your search') }} +
@@ -299,6 +310,7 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import EmptyState from '@/components/common/EmptyState.vue' import Select from '@/components/common/Select.vue' import Icon from '@/components/icons/Icon.vue' +import PlatformIcon from '@/components/common/PlatformIcon.vue' import PricingEntryCard from '@/components/admin/channel/PricingEntryCard.vue' import { getPersistedPageSize } from '@/composables/usePersistedPageSize' @@ -348,6 +360,7 @@ const deletingChannel = ref(null) // Groups const allGroups = ref([]) const groupsLoading = ref(false) +const groupSearchQuery = ref('') // Form data const form = reactive({ @@ -367,6 +380,12 @@ function formatDate(value: string): string { } // ── Group helpers ── +const filteredGroups = computed(() => { + const query = groupSearchQuery.value.trim().toLowerCase() + if (!query) return allGroups.value + return allGroups.value.filter(g => g.name.toLowerCase().includes(query)) +}) + const groupToChannelMap = computed(() => { const map = new Map() for (const ch of channels.value) { @@ -525,6 +544,7 @@ function resetForm() { form.status = 'active' form.group_ids = [] form.model_pricing = [] + groupSearchQuery.value = '' } function openCreateDialog() { @@ -558,6 +578,14 @@ async function handleSubmit() { return } + // 检查模型重复 + const allModels = form.model_pricing.flatMap(e => e.models.map(m => m.toLowerCase())) + const duplicates = allModels.filter((m, i) => allModels.indexOf(m) !== i) + if (duplicates.length > 0) { + appStore.showError(t('admin.channels.duplicateModels', `模型 "${duplicates[0]}" 在多个定价条目中重复`)) + return + } + submitting.value = true try { if (editingChannel.value) {