diff --git a/controller/model_meta.go b/controller/model_meta.go
index 9b9f7849..31ea64f3 100644
--- a/controller/model_meta.go
+++ b/controller/model_meta.go
@@ -2,6 +2,7 @@ package controller
import (
"encoding/json"
+ "sort"
"strconv"
"strings"
@@ -21,10 +22,8 @@ func GetAllModelsMeta(c *gin.Context) {
common.ApiError(c, err)
return
}
- // 填充附加字段
- for _, m := range modelsMeta {
- fillModelExtra(m)
- }
+ // 批量填充附加字段,提升列表接口性能
+ enrichModels(modelsMeta)
var total int64
model.DB.Model(&model.Model{}).Count(&total)
@@ -54,9 +53,8 @@ func SearchModelsMeta(c *gin.Context) {
common.ApiError(c, err)
return
}
- for _, m := range modelsMeta {
- fillModelExtra(m)
- }
+ // 批量填充附加字段,提升列表接口性能
+ enrichModels(modelsMeta)
pageInfo.SetTotal(int(total))
pageInfo.SetItems(modelsMeta)
common.ApiSuccess(c, pageInfo)
@@ -75,7 +73,7 @@ func GetModelMeta(c *gin.Context) {
common.ApiError(c, err)
return
}
- fillModelExtra(&m)
+ enrichModels([]*model.Model{&m})
common.ApiSuccess(c, &m)
}
@@ -162,104 +160,157 @@ func DeleteModelMeta(c *gin.Context) {
common.ApiSuccess(c, nil)
}
-// 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
-func fillModelExtra(m *model.Model) {
- // 若为精确匹配,保持原有逻辑
- if m.NameRule == model.NameRuleExact {
- if m.Endpoints == "" {
- eps := model.GetModelSupportEndpointTypes(m.ModelName)
- if b, err := json.Marshal(eps); err == nil {
- m.Endpoints = string(b)
- }
- }
- if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
- m.BoundChannels = channels
- }
- m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
- m.QuotaType = model.GetModelQuotaType(m.ModelName)
+// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
+func enrichModels(models []*model.Model) {
+ if len(models) == 0 {
return
}
- // 非精确匹配:计算并集
- pricings := model.GetPricing()
-
- // 匹配到的模型名称集合
- matchedNames := make([]string, 0)
-
- // 端点去重集合
- endpointSet := make(map[constant.EndpointType]struct{})
-
- // 已绑定渠道去重集合
- channelSet := make(map[string]model.BoundChannel)
- // 分组去重集合
- groupSet := make(map[string]struct{})
- // 计费类型(若有任意模型为 1,则返回 1)
- quotaTypeSet := make(map[int]struct{})
-
- for _, p := range pricings {
- var matched bool
- switch m.NameRule {
- case model.NameRulePrefix:
- matched = strings.HasPrefix(p.ModelName, m.ModelName)
- case model.NameRuleSuffix:
- matched = strings.HasSuffix(p.ModelName, m.ModelName)
- case model.NameRuleContains:
- matched = strings.Contains(p.ModelName, m.ModelName)
- }
- if !matched {
+ // 1) 拆分精确与规则匹配
+ exactNames := make([]string, 0)
+ exactIdx := make(map[string][]int) // modelName -> indices in models
+ ruleIndices := make([]int, 0)
+ for i, m := range models {
+ if m == nil {
continue
}
-
- // 记录匹配到的模型名称
- matchedNames = append(matchedNames, p.ModelName)
-
- // 收集端点
- for _, et := range p.SupportedEndpointTypes {
- endpointSet[et] = struct{}{}
- }
-
- // 收集分组
- for _, g := range p.EnableGroup {
- groupSet[g] = struct{}{}
- }
-
- // 收集计费类型
- quotaTypeSet[p.QuotaType] = struct{}{}
- }
-
- // 序列化端点
- if len(endpointSet) > 0 && m.Endpoints == "" {
- eps := make([]constant.EndpointType, 0, len(endpointSet))
- for et := range endpointSet {
- eps = append(eps, et)
- }
- if b, err := json.Marshal(eps); err == nil {
- m.Endpoints = string(b)
+ if m.NameRule == model.NameRuleExact {
+ exactNames = append(exactNames, m.ModelName)
+ exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
+ } else {
+ ruleIndices = append(ruleIndices, i)
}
}
- // 序列化分组
- if len(groupSet) > 0 {
- groups := make([]string, 0, len(groupSet))
- for g := range groupSet {
- groups = append(groups, g)
+ // 2) 批量查询精确模型的绑定渠道
+ channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
+
+ // 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
+ for name, indices := range exactIdx {
+ chs := channelsByModel[name]
+ for _, idx := range indices {
+ mm := models[idx]
+ if mm.Endpoints == "" {
+ eps := model.GetModelSupportEndpointTypes(mm.ModelName)
+ if b, err := json.Marshal(eps); err == nil {
+ mm.Endpoints = string(b)
+ }
+ }
+ mm.BoundChannels = chs
+ mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
+ mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
}
- m.EnableGroups = groups
}
- // 确定计费类型:仅当所有匹配模型计费类型一致时才返回该类型,否则返回 -1 表示未知/不确定
- if len(quotaTypeSet) == 1 {
- for k := range quotaTypeSet {
- m.QuotaType = k
- }
- } else {
- m.QuotaType = -1
+ if len(ruleIndices) == 0 {
+ return
}
- // 批量查询并序列化渠道
- if len(matchedNames) > 0 {
- if channels, err := model.GetBoundChannelsForModels(matchedNames); err == nil {
- for _, ch := range channels {
+ // 4) 一次性读取定价缓存,内存匹配所有规则模型
+ pricings := model.GetPricing()
+
+ // 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
+ matchedNamesByIdx := make(map[int][]string)
+ endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
+ groupSetByIdx := make(map[int]map[string]struct{})
+ quotaSetByIdx := make(map[int]map[int]struct{})
+
+ for _, p := range pricings {
+ for _, idx := range ruleIndices {
+ mm := models[idx]
+ var matched bool
+ switch mm.NameRule {
+ case model.NameRulePrefix:
+ matched = strings.HasPrefix(p.ModelName, mm.ModelName)
+ case model.NameRuleSuffix:
+ matched = strings.HasSuffix(p.ModelName, mm.ModelName)
+ case model.NameRuleContains:
+ matched = strings.Contains(p.ModelName, mm.ModelName)
+ }
+ if !matched {
+ continue
+ }
+ matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
+
+ es := endpointSetByIdx[idx]
+ if es == nil {
+ es = make(map[constant.EndpointType]struct{})
+ endpointSetByIdx[idx] = es
+ }
+ for _, et := range p.SupportedEndpointTypes {
+ es[et] = struct{}{}
+ }
+
+ gs := groupSetByIdx[idx]
+ if gs == nil {
+ gs = make(map[string]struct{})
+ groupSetByIdx[idx] = gs
+ }
+ for _, g := range p.EnableGroup {
+ gs[g] = struct{}{}
+ }
+
+ qs := quotaSetByIdx[idx]
+ if qs == nil {
+ qs = make(map[int]struct{})
+ quotaSetByIdx[idx] = qs
+ }
+ qs[p.QuotaType] = struct{}{}
+ }
+ }
+
+ // 5) 汇总所有匹配到的模型名称,批量查询一次渠道
+ allMatchedSet := make(map[string]struct{})
+ for _, names := range matchedNamesByIdx {
+ for _, n := range names {
+ allMatchedSet[n] = struct{}{}
+ }
+ }
+ allMatched := make([]string, 0, len(allMatchedSet))
+ for n := range allMatchedSet {
+ allMatched = append(allMatched, n)
+ }
+ matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
+
+ // 6) 回填每个规则模型的并集信息
+ for _, idx := range ruleIndices {
+ mm := models[idx]
+
+ // 端点并集 -> 序列化
+ if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
+ eps := make([]constant.EndpointType, 0, len(es))
+ for et := range es {
+ eps = append(eps, et)
+ }
+ if b, err := json.Marshal(eps); err == nil {
+ mm.Endpoints = string(b)
+ }
+ }
+
+ // 分组并集
+ if gs, ok := groupSetByIdx[idx]; ok {
+ groups := make([]string, 0, len(gs))
+ for g := range gs {
+ groups = append(groups, g)
+ }
+ mm.EnableGroups = groups
+ }
+
+ // 配额类型集合(保持去重并排序)
+ if qs, ok := quotaSetByIdx[idx]; ok {
+ arr := make([]int, 0, len(qs))
+ for k := range qs {
+ arr = append(arr, k)
+ }
+ sort.Ints(arr)
+ mm.QuotaTypes = arr
+ }
+
+ // 渠道并集
+ names := matchedNamesByIdx[idx]
+ channelSet := make(map[string]model.BoundChannel)
+ for _, n := range names {
+ for _, ch := range matchedChannelsByModel[n] {
key := ch.Name + "_" + strconv.Itoa(ch.Type)
channelSet[key] = ch
}
@@ -269,11 +320,11 @@ func fillModelExtra(m *model.Model) {
for _, ch := range channelSet {
chs = append(chs, ch)
}
- m.BoundChannels = chs
+ mm.BoundChannels = chs
}
- }
- // 设置匹配信息
- m.MatchedModels = matchedNames
- m.MatchedCount = len(matchedNames)
+ // 匹配信息
+ mm.MatchedModels = names
+ mm.MatchedCount = len(names)
+ }
}
diff --git a/model/model_extra.go b/model/model_extra.go
index 495b5bb9..71fd84e7 100644
--- a/model/model_extra.go
+++ b/model/model_extra.go
@@ -1,7 +1,5 @@
package model
-// GetModelEnableGroups 返回指定模型名称可用的用户分组列表。
-// 使用在 updatePricing() 中维护的缓存映射,O(1) 读取,适合高并发场景。
func GetModelEnableGroups(modelName string) []string {
// 确保缓存最新
GetPricing()
@@ -19,16 +17,15 @@ func GetModelEnableGroups(modelName string) []string {
return groups
}
-// GetModelQuotaType 返回指定模型的计费类型(quota_type)。
-// 同样使用缓存映射,避免每次遍历定价切片。
-func GetModelQuotaType(modelName string) int {
+// GetModelQuotaTypes 返回指定模型的计费类型集合(来自缓存)
+func GetModelQuotaTypes(modelName string) []int {
GetPricing()
modelEnableGroupsLock.RLock()
quota, ok := modelQuotaTypeMap[modelName]
modelEnableGroupsLock.RUnlock()
if !ok {
- return 0
+ return []int{}
}
- return quota
+ return []int{quota}
}
diff --git a/model/model_meta.go b/model/model_meta.go
index 205c8975..b7602b0e 100644
--- a/model/model_meta.go
+++ b/model/model_meta.go
@@ -3,30 +3,15 @@ package model
import (
"one-api/common"
"strconv"
- "strings"
"gorm.io/gorm"
)
-// Model 用于存储模型的元数据,例如描述、标签等
-// ModelName 字段具有唯一性约束,确保每个模型只会出现一次
-// Tags 字段使用逗号分隔的字符串保存标签集合,后期可根据需要扩展为 JSON 类型
-// Status: 1 表示启用,0 表示禁用,保留以便后续功能扩展
-// CreatedTime 和 UpdatedTime 使用 Unix 时间戳(秒)保存方便跨数据库移植
-// DeletedAt 采用 GORM 的软删除特性,便于后续数据恢复
-//
-// 该表设计遵循第三范式(3NF):
-// 1. 每一列都与主键(Id 或 ModelName)直接相关
-// 2. 不存在部分依赖(ModelName 是唯一键)
-// 3. 不存在传递依赖(描述、标签等都依赖于 ModelName,而非依赖于其他非主键列)
-// 这样既保证了数据一致性,也方便后期扩展
-
-// 模型名称匹配规则
const (
- NameRuleExact = iota // 0 精确匹配
- NameRulePrefix // 1 前缀匹配
- NameRuleContains // 2 包含匹配
- NameRuleSuffix // 3 后缀匹配
+ NameRuleExact = iota
+ NameRulePrefix
+ NameRuleContains
+ NameRuleSuffix
)
type BoundChannel struct {
@@ -49,14 +34,13 @@ type Model struct {
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
- QuotaType int `json:"quota_type" gorm:"-"`
+ QuotaTypes []int `json:"quota_types,omitempty" gorm:"-"`
NameRule int `json:"name_rule" gorm:"default:0"`
MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
MatchedCount int `json:"matched_count,omitempty" gorm:"-"`
}
-// Insert 创建新的模型元数据记录
func (mi *Model) Insert() error {
now := common.GetTimestamp()
mi.CreatedTime = now
@@ -64,7 +48,6 @@ func (mi *Model) Insert() error {
return DB.Create(mi).Error
}
-// IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID)
func IsModelNameDuplicated(id int, name string) (bool, error) {
if name == "" {
return false, nil
@@ -74,10 +57,8 @@ func IsModelNameDuplicated(id int, name string) (bool, error) {
return cnt > 0, err
}
-// Update 更新现有模型记录
func (mi *Model) Update() error {
mi.UpdatedTime = common.GetTimestamp()
- // 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新
return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
Model(&Model{}).
Where("id = ?", mi.Id).
@@ -86,22 +67,10 @@ func (mi *Model) Update() error {
Updates(mi).Error
}
-// Delete 软删除模型记录
func (mi *Model) Delete() error {
return DB.Delete(mi).Error
}
-// GetModelByName 根据模型名称查询元数据
-func GetModelByName(name string) (*Model, error) {
- var mi Model
- err := DB.Where("model_name = ?", name).First(&mi).Error
- if err != nil {
- return nil, err
- }
- return &mi, nil
-}
-
-// GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响)
func GetVendorModelCounts() (map[int64]int64, error) {
var stats []struct {
VendorID int64
@@ -120,87 +89,38 @@ func GetVendorModelCounts() (map[int64]int64, error) {
return m, nil
}
-// GetAllModels 分页获取所有模型元数据
func GetAllModels(offset int, limit int) ([]*Model, error) {
var models []*Model
- err := DB.Offset(offset).Limit(limit).Find(&models).Error
+ err := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&models).Error
return models, err
}
-// GetBoundChannels 查询支持该模型的渠道(名称+类型)
-func GetBoundChannels(modelName string) ([]BoundChannel, error) {
- var channels []BoundChannel
- err := DB.Table("channels").
- Select("channels.name, channels.type").
- Joins("join abilities on abilities.channel_id = channels.id").
- Where("abilities.model = ? AND abilities.enabled = ?", modelName, true).
- Group("channels.id").
- Scan(&channels).Error
- return channels, err
-}
-
-// GetBoundChannelsForModels 批量查询多模型的绑定渠道并去重返回
-func GetBoundChannelsForModels(modelNames []string) ([]BoundChannel, error) {
+func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) {
+ result := make(map[string][]BoundChannel)
if len(modelNames) == 0 {
- return make([]BoundChannel, 0), nil
+ return result, nil
}
- var channels []BoundChannel
+ type row struct {
+ Model string
+ Name string
+ Type int
+ }
+ var rows []row
err := DB.Table("channels").
- Select("channels.name, channels.type").
- Joins("join abilities on abilities.channel_id = channels.id").
+ Select("abilities.model as model, channels.name as name, channels.type as type").
+ Joins("JOIN abilities ON abilities.channel_id = channels.id").
Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
- Group("channels.id").
- Scan(&channels).Error
- return channels, err
-}
-
-// FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含
-func FindModelByNameWithRule(name string) (*Model, error) {
- // 1. 精确匹配
- if m, err := GetModelByName(name); err == nil {
- return m, nil
- }
- // 2. 规则匹配
- var models []*Model
- if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil {
+ Distinct().
+ Scan(&rows).Error
+ if err != nil {
return nil, err
}
- var prefixMatch, suffixMatch, containsMatch *Model
- for _, m := range models {
- switch m.NameRule {
- case NameRulePrefix:
- if strings.HasPrefix(name, m.ModelName) {
- if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) {
- prefixMatch = m
- }
- }
- case NameRuleSuffix:
- if strings.HasSuffix(name, m.ModelName) {
- if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) {
- suffixMatch = m
- }
- }
- case NameRuleContains:
- if strings.Contains(name, m.ModelName) {
- if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) {
- containsMatch = m
- }
- }
- }
+ for _, r := range rows {
+ result[r.Model] = append(result[r.Model], BoundChannel{Name: r.Name, Type: r.Type})
}
- if prefixMatch != nil {
- return prefixMatch, nil
- }
- if suffixMatch != nil {
- return suffixMatch, nil
- }
- if containsMatch != nil {
- return containsMatch, nil
- }
- return nil, gorm.ErrRecordNotFound
+ return result, nil
}
-// SearchModels 根据关键词和供应商搜索模型,支持分页
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
var models []*Model
db := DB.Model(&Model{})
@@ -209,7 +129,6 @@ func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Mode
db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
}
if vendor != "" {
- // 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配
if vid, err := strconv.Atoi(vendor); err == nil {
db = db.Where("models.vendor_id = ?", vid)
} else {
@@ -217,10 +136,11 @@ func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Mode
}
}
var total int64
- err := db.Count(&total).Error
- if err != nil {
+ if err := db.Count(&total).Error; err != nil {
return nil, 0, err
}
- err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error
- return models, total, err
+ if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
+ return nil, 0, err
+ }
+ return models, total, nil
}
diff --git a/web/src/components/table/model-pricing/view/table/PricingTableColumns.js b/web/src/components/table/model-pricing/view/table/PricingTableColumns.js
index 18e1fc89..e2b34cce 100644
--- a/web/src/components/table/model-pricing/view/table/PricingTableColumns.js
+++ b/web/src/components/table/model-pricing/view/table/PricingTableColumns.js
@@ -23,23 +23,31 @@ import { IconHelpCircle } from '@douyinfe/semi-icons';
import { renderModelTag, stringToColor, calculateModelPrice, getLobeHubIcon } from '../../../../../helpers';
import { renderLimitedItems, renderDescription } from '../../../../common/ui/RenderUtils';
-function renderQuotaType(type, t) {
- switch (type) {
- case 1:
- return (
-