diff --git a/controller/model_meta.go b/controller/model_meta.go index 9b9f7849..31ea64f3 100644 --- a/controller/model_meta.go +++ b/controller/model_meta.go @@ -2,6 +2,7 @@ package controller import ( "encoding/json" + "sort" "strconv" "strings" @@ -21,10 +22,8 @@ func GetAllModelsMeta(c *gin.Context) { common.ApiError(c, err) return } - // 填充附加字段 - for _, m := range modelsMeta { - fillModelExtra(m) - } + // 批量填充附加字段,提升列表接口性能 + enrichModels(modelsMeta) var total int64 model.DB.Model(&model.Model{}).Count(&total) @@ -54,9 +53,8 @@ func SearchModelsMeta(c *gin.Context) { common.ApiError(c, err) return } - for _, m := range modelsMeta { - fillModelExtra(m) - } + // 批量填充附加字段,提升列表接口性能 + enrichModels(modelsMeta) pageInfo.SetTotal(int(total)) pageInfo.SetItems(modelsMeta) common.ApiSuccess(c, pageInfo) @@ -75,7 +73,7 @@ func GetModelMeta(c *gin.Context) { common.ApiError(c, err) return } - fillModelExtra(&m) + enrichModels([]*model.Model{&m}) common.ApiSuccess(c, &m) } @@ -162,104 +160,157 @@ func DeleteModelMeta(c *gin.Context) { common.ApiSuccess(c, nil) } -// 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups -func fillModelExtra(m *model.Model) { - // 若为精确匹配,保持原有逻辑 - if m.NameRule == model.NameRuleExact { - if m.Endpoints == "" { - eps := model.GetModelSupportEndpointTypes(m.ModelName) - if b, err := json.Marshal(eps); err == nil { - m.Endpoints = string(b) - } - } - if channels, err := model.GetBoundChannels(m.ModelName); err == nil { - m.BoundChannels = channels - } - m.EnableGroups = model.GetModelEnableGroups(m.ModelName) - m.QuotaType = model.GetModelQuotaType(m.ModelName) +// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询 +func enrichModels(models []*model.Model) { + if len(models) == 0 { return } - // 非精确匹配:计算并集 - pricings := model.GetPricing() - - // 匹配到的模型名称集合 - matchedNames := make([]string, 0) - - // 端点去重集合 - endpointSet := make(map[constant.EndpointType]struct{}) - - // 已绑定渠道去重集合 - channelSet := make(map[string]model.BoundChannel) - // 分组去重集合 - groupSet := make(map[string]struct{}) - // 计费类型(若有任意模型为 1,则返回 1) - quotaTypeSet := make(map[int]struct{}) - - for _, p := range pricings { - var matched bool - switch m.NameRule { - case model.NameRulePrefix: - matched = strings.HasPrefix(p.ModelName, m.ModelName) - case model.NameRuleSuffix: - matched = strings.HasSuffix(p.ModelName, m.ModelName) - case model.NameRuleContains: - matched = strings.Contains(p.ModelName, m.ModelName) - } - if !matched { + // 1) 拆分精确与规则匹配 + exactNames := make([]string, 0) + exactIdx := make(map[string][]int) // modelName -> indices in models + ruleIndices := make([]int, 0) + for i, m := range models { + if m == nil { continue } - - // 记录匹配到的模型名称 - matchedNames = append(matchedNames, p.ModelName) - - // 收集端点 - for _, et := range p.SupportedEndpointTypes { - endpointSet[et] = struct{}{} - } - - // 收集分组 - for _, g := range p.EnableGroup { - groupSet[g] = struct{}{} - } - - // 收集计费类型 - quotaTypeSet[p.QuotaType] = struct{}{} - } - - // 序列化端点 - if len(endpointSet) > 0 && m.Endpoints == "" { - eps := make([]constant.EndpointType, 0, len(endpointSet)) - for et := range endpointSet { - eps = append(eps, et) - } - if b, err := json.Marshal(eps); err == nil { - m.Endpoints = string(b) + if m.NameRule == model.NameRuleExact { + exactNames = append(exactNames, m.ModelName) + exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i) + } else { + ruleIndices = append(ruleIndices, i) } } - // 序列化分组 - if len(groupSet) > 0 { - groups := make([]string, 0, len(groupSet)) - for g := range groupSet { - groups = append(groups, g) + // 2) 批量查询精确模型的绑定渠道 + channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames) + + // 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存 + for name, indices := range exactIdx { + chs := channelsByModel[name] + for _, idx := range indices { + mm := models[idx] + if mm.Endpoints == "" { + eps := model.GetModelSupportEndpointTypes(mm.ModelName) + if b, err := json.Marshal(eps); err == nil { + mm.Endpoints = string(b) + } + } + mm.BoundChannels = chs + mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName) + mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName) } - m.EnableGroups = groups } - // 确定计费类型:仅当所有匹配模型计费类型一致时才返回该类型,否则返回 -1 表示未知/不确定 - if len(quotaTypeSet) == 1 { - for k := range quotaTypeSet { - m.QuotaType = k - } - } else { - m.QuotaType = -1 + if len(ruleIndices) == 0 { + return } - // 批量查询并序列化渠道 - if len(matchedNames) > 0 { - if channels, err := model.GetBoundChannelsForModels(matchedNames); err == nil { - for _, ch := range channels { + // 4) 一次性读取定价缓存,内存匹配所有规则模型 + pricings := model.GetPricing() + + // 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合 + matchedNamesByIdx := make(map[int][]string) + endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{}) + groupSetByIdx := make(map[int]map[string]struct{}) + quotaSetByIdx := make(map[int]map[int]struct{}) + + for _, p := range pricings { + for _, idx := range ruleIndices { + mm := models[idx] + var matched bool + switch mm.NameRule { + case model.NameRulePrefix: + matched = strings.HasPrefix(p.ModelName, mm.ModelName) + case model.NameRuleSuffix: + matched = strings.HasSuffix(p.ModelName, mm.ModelName) + case model.NameRuleContains: + matched = strings.Contains(p.ModelName, mm.ModelName) + } + if !matched { + continue + } + matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName) + + es := endpointSetByIdx[idx] + if es == nil { + es = make(map[constant.EndpointType]struct{}) + endpointSetByIdx[idx] = es + } + for _, et := range p.SupportedEndpointTypes { + es[et] = struct{}{} + } + + gs := groupSetByIdx[idx] + if gs == nil { + gs = make(map[string]struct{}) + groupSetByIdx[idx] = gs + } + for _, g := range p.EnableGroup { + gs[g] = struct{}{} + } + + qs := quotaSetByIdx[idx] + if qs == nil { + qs = make(map[int]struct{}) + quotaSetByIdx[idx] = qs + } + qs[p.QuotaType] = struct{}{} + } + } + + // 5) 汇总所有匹配到的模型名称,批量查询一次渠道 + allMatchedSet := make(map[string]struct{}) + for _, names := range matchedNamesByIdx { + for _, n := range names { + allMatchedSet[n] = struct{}{} + } + } + allMatched := make([]string, 0, len(allMatchedSet)) + for n := range allMatchedSet { + allMatched = append(allMatched, n) + } + matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched) + + // 6) 回填每个规则模型的并集信息 + for _, idx := range ruleIndices { + mm := models[idx] + + // 端点并集 -> 序列化 + if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" { + eps := make([]constant.EndpointType, 0, len(es)) + for et := range es { + eps = append(eps, et) + } + if b, err := json.Marshal(eps); err == nil { + mm.Endpoints = string(b) + } + } + + // 分组并集 + if gs, ok := groupSetByIdx[idx]; ok { + groups := make([]string, 0, len(gs)) + for g := range gs { + groups = append(groups, g) + } + mm.EnableGroups = groups + } + + // 配额类型集合(保持去重并排序) + if qs, ok := quotaSetByIdx[idx]; ok { + arr := make([]int, 0, len(qs)) + for k := range qs { + arr = append(arr, k) + } + sort.Ints(arr) + mm.QuotaTypes = arr + } + + // 渠道并集 + names := matchedNamesByIdx[idx] + channelSet := make(map[string]model.BoundChannel) + for _, n := range names { + for _, ch := range matchedChannelsByModel[n] { key := ch.Name + "_" + strconv.Itoa(ch.Type) channelSet[key] = ch } @@ -269,11 +320,11 @@ func fillModelExtra(m *model.Model) { for _, ch := range channelSet { chs = append(chs, ch) } - m.BoundChannels = chs + mm.BoundChannels = chs } - } - // 设置匹配信息 - m.MatchedModels = matchedNames - m.MatchedCount = len(matchedNames) + // 匹配信息 + mm.MatchedModels = names + mm.MatchedCount = len(names) + } } diff --git a/model/model_extra.go b/model/model_extra.go index 495b5bb9..71fd84e7 100644 --- a/model/model_extra.go +++ b/model/model_extra.go @@ -1,7 +1,5 @@ package model -// GetModelEnableGroups 返回指定模型名称可用的用户分组列表。 -// 使用在 updatePricing() 中维护的缓存映射,O(1) 读取,适合高并发场景。 func GetModelEnableGroups(modelName string) []string { // 确保缓存最新 GetPricing() @@ -19,16 +17,15 @@ func GetModelEnableGroups(modelName string) []string { return groups } -// GetModelQuotaType 返回指定模型的计费类型(quota_type)。 -// 同样使用缓存映射,避免每次遍历定价切片。 -func GetModelQuotaType(modelName string) int { +// GetModelQuotaTypes 返回指定模型的计费类型集合(来自缓存) +func GetModelQuotaTypes(modelName string) []int { GetPricing() modelEnableGroupsLock.RLock() quota, ok := modelQuotaTypeMap[modelName] modelEnableGroupsLock.RUnlock() if !ok { - return 0 + return []int{} } - return quota + return []int{quota} } diff --git a/model/model_meta.go b/model/model_meta.go index 205c8975..b7602b0e 100644 --- a/model/model_meta.go +++ b/model/model_meta.go @@ -3,30 +3,15 @@ package model import ( "one-api/common" "strconv" - "strings" "gorm.io/gorm" ) -// Model 用于存储模型的元数据,例如描述、标签等 -// ModelName 字段具有唯一性约束,确保每个模型只会出现一次 -// Tags 字段使用逗号分隔的字符串保存标签集合,后期可根据需要扩展为 JSON 类型 -// Status: 1 表示启用,0 表示禁用,保留以便后续功能扩展 -// CreatedTime 和 UpdatedTime 使用 Unix 时间戳(秒)保存方便跨数据库移植 -// DeletedAt 采用 GORM 的软删除特性,便于后续数据恢复 -// -// 该表设计遵循第三范式(3NF): -// 1. 每一列都与主键(Id 或 ModelName)直接相关 -// 2. 不存在部分依赖(ModelName 是唯一键) -// 3. 不存在传递依赖(描述、标签等都依赖于 ModelName,而非依赖于其他非主键列) -// 这样既保证了数据一致性,也方便后期扩展 - -// 模型名称匹配规则 const ( - NameRuleExact = iota // 0 精确匹配 - NameRulePrefix // 1 前缀匹配 - NameRuleContains // 2 包含匹配 - NameRuleSuffix // 3 后缀匹配 + NameRuleExact = iota + NameRulePrefix + NameRuleContains + NameRuleSuffix ) type BoundChannel struct { @@ -49,14 +34,13 @@ type Model struct { BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"` EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"` - QuotaType int `json:"quota_type" gorm:"-"` + QuotaTypes []int `json:"quota_types,omitempty" gorm:"-"` NameRule int `json:"name_rule" gorm:"default:0"` MatchedModels []string `json:"matched_models,omitempty" gorm:"-"` MatchedCount int `json:"matched_count,omitempty" gorm:"-"` } -// Insert 创建新的模型元数据记录 func (mi *Model) Insert() error { now := common.GetTimestamp() mi.CreatedTime = now @@ -64,7 +48,6 @@ func (mi *Model) Insert() error { return DB.Create(mi).Error } -// IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID) func IsModelNameDuplicated(id int, name string) (bool, error) { if name == "" { return false, nil @@ -74,10 +57,8 @@ func IsModelNameDuplicated(id int, name string) (bool, error) { return cnt > 0, err } -// Update 更新现有模型记录 func (mi *Model) Update() error { mi.UpdatedTime = common.GetTimestamp() - // 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新 return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}). Model(&Model{}). Where("id = ?", mi.Id). @@ -86,22 +67,10 @@ func (mi *Model) Update() error { Updates(mi).Error } -// Delete 软删除模型记录 func (mi *Model) Delete() error { return DB.Delete(mi).Error } -// GetModelByName 根据模型名称查询元数据 -func GetModelByName(name string) (*Model, error) { - var mi Model - err := DB.Where("model_name = ?", name).First(&mi).Error - if err != nil { - return nil, err - } - return &mi, nil -} - -// GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响) func GetVendorModelCounts() (map[int64]int64, error) { var stats []struct { VendorID int64 @@ -120,87 +89,38 @@ func GetVendorModelCounts() (map[int64]int64, error) { return m, nil } -// GetAllModels 分页获取所有模型元数据 func GetAllModels(offset int, limit int) ([]*Model, error) { var models []*Model - err := DB.Offset(offset).Limit(limit).Find(&models).Error + err := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&models).Error return models, err } -// GetBoundChannels 查询支持该模型的渠道(名称+类型) -func GetBoundChannels(modelName string) ([]BoundChannel, error) { - var channels []BoundChannel - err := DB.Table("channels"). - Select("channels.name, channels.type"). - Joins("join abilities on abilities.channel_id = channels.id"). - Where("abilities.model = ? AND abilities.enabled = ?", modelName, true). - Group("channels.id"). - Scan(&channels).Error - return channels, err -} - -// GetBoundChannelsForModels 批量查询多模型的绑定渠道并去重返回 -func GetBoundChannelsForModels(modelNames []string) ([]BoundChannel, error) { +func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) { + result := make(map[string][]BoundChannel) if len(modelNames) == 0 { - return make([]BoundChannel, 0), nil + return result, nil } - var channels []BoundChannel + type row struct { + Model string + Name string + Type int + } + var rows []row err := DB.Table("channels"). - Select("channels.name, channels.type"). - Joins("join abilities on abilities.channel_id = channels.id"). + Select("abilities.model as model, channels.name as name, channels.type as type"). + Joins("JOIN abilities ON abilities.channel_id = channels.id"). Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true). - Group("channels.id"). - Scan(&channels).Error - return channels, err -} - -// FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含 -func FindModelByNameWithRule(name string) (*Model, error) { - // 1. 精确匹配 - if m, err := GetModelByName(name); err == nil { - return m, nil - } - // 2. 规则匹配 - var models []*Model - if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil { + Distinct(). + Scan(&rows).Error + if err != nil { return nil, err } - var prefixMatch, suffixMatch, containsMatch *Model - for _, m := range models { - switch m.NameRule { - case NameRulePrefix: - if strings.HasPrefix(name, m.ModelName) { - if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) { - prefixMatch = m - } - } - case NameRuleSuffix: - if strings.HasSuffix(name, m.ModelName) { - if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) { - suffixMatch = m - } - } - case NameRuleContains: - if strings.Contains(name, m.ModelName) { - if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) { - containsMatch = m - } - } - } + for _, r := range rows { + result[r.Model] = append(result[r.Model], BoundChannel{Name: r.Name, Type: r.Type}) } - if prefixMatch != nil { - return prefixMatch, nil - } - if suffixMatch != nil { - return suffixMatch, nil - } - if containsMatch != nil { - return containsMatch, nil - } - return nil, gorm.ErrRecordNotFound + return result, nil } -// SearchModels 根据关键词和供应商搜索模型,支持分页 func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) { var models []*Model db := DB.Model(&Model{}) @@ -209,7 +129,6 @@ func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Mode db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like) } if vendor != "" { - // 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配 if vid, err := strconv.Atoi(vendor); err == nil { db = db.Where("models.vendor_id = ?", vid) } else { @@ -217,10 +136,11 @@ func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Mode } } var total int64 - err := db.Count(&total).Error - if err != nil { + if err := db.Count(&total).Error; err != nil { return nil, 0, err } - err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error - return models, total, err + if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil { + return nil, 0, err + } + return models, total, nil } diff --git a/web/src/components/table/model-pricing/view/table/PricingTableColumns.js b/web/src/components/table/model-pricing/view/table/PricingTableColumns.js index 18e1fc89..e2b34cce 100644 --- a/web/src/components/table/model-pricing/view/table/PricingTableColumns.js +++ b/web/src/components/table/model-pricing/view/table/PricingTableColumns.js @@ -23,23 +23,31 @@ import { IconHelpCircle } from '@douyinfe/semi-icons'; import { renderModelTag, stringToColor, calculateModelPrice, getLobeHubIcon } from '../../../../../helpers'; import { renderLimitedItems, renderDescription } from '../../../../common/ui/RenderUtils'; -function renderQuotaType(type, t) { - switch (type) { - case 1: - return ( - - {t('按次计费')} - - ); - case 0: - return ( - - {t('按量计费')} - - ); - default: - return t('未知'); - } +function renderQuotaTypes(types, t) { + if (!Array.isArray(types) || types.length === 0) return '-'; + const renderOne = (type, idx) => { + switch (type) { + case 1: + return ( + + {t('按次计费')} + + ); + case 0: + return ( + + {t('按量计费')} + + ); + default: + return ( + + {type} + + ); + } + }; + return {types.map((t0, idx) => renderOne(t0, idx))}; } // Render vendor name @@ -122,11 +130,8 @@ export const getPricingTableColumns = ({ const quotaColumn = { title: t('计费类型'), - dataIndex: 'quota_type', - render: (text, record, index) => { - return renderQuotaType(parseInt(text), t); - }, - sorter: (a, b) => a.quota_type - b.quota_type, + dataIndex: 'quota_types', + render: (text, record, index) => renderQuotaTypes(text, t), }; const descriptionColumn = { @@ -170,11 +175,11 @@ export const getPricingTableColumns = ({ const content = (
- {t('模型倍率')}:{record.quota_type === 0 ? text : t('无')} + {t('模型倍率')}:{Array.isArray(record.quota_types) && record.quota_types.includes(0) ? text : t('无')}
{t('补全倍率')}: - {record.quota_type === 0 ? completionRatio : t('无')} + {Array.isArray(record.quota_types) && record.quota_types.includes(0) ? completionRatio : t('无')}
{t('分组倍率')}:{groupRatio[selectedGroup]} diff --git a/web/src/components/table/models/ModelsColumnDefs.js b/web/src/components/table/models/ModelsColumnDefs.js index e1fc257e..bee321d9 100644 --- a/web/src/components/table/models/ModelsColumnDefs.js +++ b/web/src/components/table/models/ModelsColumnDefs.js @@ -121,24 +121,36 @@ const renderEndpoints = (value) => { } }; -// Render quota type -const renderQuotaType = (qt, t) => { - if (qt === 1) { +// Render quota types (array) +const renderQuotaTypes = (arr, t) => { + if (!Array.isArray(arr) || arr.length === 0) return '-'; + const renderOne = (qt, idx) => { + if (qt === 1) { + return ( + + {t('按次计费')} + + ); + } + if (qt === 0) { + return ( + + {t('按量计费')} + + ); + } + // 未来新增模式的兜底展示 return ( - - {t('按次计费')} + + {qt} ); - } - if (qt === 0) { - return ( - - {t('按量计费')} - - ); - } - // 未知 - return '-'; + }; + return ( + + {arr.map((qt, idx) => renderOne(qt, idx))} + + ); }; // Render bound channels @@ -303,8 +315,8 @@ export const getModelsColumns = ({ }, { title: t('计费类型'), - dataIndex: 'quota_type', - render: (qt) => renderQuotaType(qt, t), + dataIndex: 'quota_types', + render: (qts) => renderQuotaTypes(qts, t), }, { title: t('创建时间'),