Backend - Add GetBoundChannelsByModelsMap to batch-fetch bound channels via a single JOIN (Distinct), compatible with SQLite/MySQL/PostgreSQL - Replace per-record enrichment with a single-pass enrichModels to avoid N+1 queries; compute unions for prefix/suffix/contains matches in memory - Change Model.QuotaType to QuotaTypes []int and expose quota_types in responses - Add GetModelQuotaTypes for cached O(1) lookups; exact models return a single-element array - Sort quota_types for stable output order - Remove unused code: GetModelByName, GetBoundChannels, GetBoundChannelsForModels, FindModelByNameWithRule, buildPrefixes, buildSuffixes - Clean up redundant comments, keeping concise and readable code Frontend - Models table: switch to quota_types, render multiple billing modes ([0], [1], [0,1], future values supported) - Pricing table: switch to quota_types; ratio display now checks quota_types.includes(0); array rendering for billing tags Compatibility - SQL uses standard JOIN/IN/Distinct; works across SQLite/MySQL/PostgreSQL - Lint passes; no DB schema changes (quota_types is a JSON response field only) Breaking Change - API field renamed: quota_type -> quota_types (array). Update clients accordingly.
331 lines
8.0 KiB
Go
331 lines
8.0 KiB
Go
package controller
|
|
|
|
import (
|
|
"encoding/json"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"one-api/model"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// GetAllModelsMeta 获取模型列表(分页)
|
|
func GetAllModelsMeta(c *gin.Context) {
|
|
|
|
pageInfo := common.GetPageQuery(c)
|
|
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
// 批量填充附加字段,提升列表接口性能
|
|
enrichModels(modelsMeta)
|
|
var total int64
|
|
model.DB.Model(&model.Model{}).Count(&total)
|
|
|
|
// 统计供应商计数(全部数据,不受分页影响)
|
|
vendorCounts, _ := model.GetVendorModelCounts()
|
|
|
|
pageInfo.SetTotal(int(total))
|
|
pageInfo.SetItems(modelsMeta)
|
|
common.ApiSuccess(c, gin.H{
|
|
"items": modelsMeta,
|
|
"total": total,
|
|
"page": pageInfo.GetPage(),
|
|
"page_size": pageInfo.GetPageSize(),
|
|
"vendor_counts": vendorCounts,
|
|
})
|
|
}
|
|
|
|
// SearchModelsMeta 搜索模型列表
|
|
func SearchModelsMeta(c *gin.Context) {
|
|
|
|
keyword := c.Query("keyword")
|
|
vendor := c.Query("vendor")
|
|
pageInfo := common.GetPageQuery(c)
|
|
|
|
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
// 批量填充附加字段,提升列表接口性能
|
|
enrichModels(modelsMeta)
|
|
pageInfo.SetTotal(int(total))
|
|
pageInfo.SetItems(modelsMeta)
|
|
common.ApiSuccess(c, pageInfo)
|
|
}
|
|
|
|
// GetModelMeta 根据 ID 获取单条模型信息
|
|
func GetModelMeta(c *gin.Context) {
|
|
idStr := c.Param("id")
|
|
id, err := strconv.Atoi(idStr)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
var m model.Model
|
|
if err := model.DB.First(&m, id).Error; err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
enrichModels([]*model.Model{&m})
|
|
common.ApiSuccess(c, &m)
|
|
}
|
|
|
|
// CreateModelMeta 新建模型
|
|
func CreateModelMeta(c *gin.Context) {
|
|
var m model.Model
|
|
if err := c.ShouldBindJSON(&m); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
if m.ModelName == "" {
|
|
common.ApiErrorMsg(c, "模型名称不能为空")
|
|
return
|
|
}
|
|
// 名称冲突检查
|
|
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
} else if dup {
|
|
common.ApiErrorMsg(c, "模型名称已存在")
|
|
return
|
|
}
|
|
|
|
if err := m.Insert(); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
model.RefreshPricing()
|
|
common.ApiSuccess(c, &m)
|
|
}
|
|
|
|
// UpdateModelMeta 更新模型
|
|
func UpdateModelMeta(c *gin.Context) {
|
|
statusOnly := c.Query("status_only") == "true"
|
|
|
|
var m model.Model
|
|
if err := c.ShouldBindJSON(&m); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
if m.Id == 0 {
|
|
common.ApiErrorMsg(c, "缺少模型 ID")
|
|
return
|
|
}
|
|
|
|
if statusOnly {
|
|
// 只更新状态,防止误清空其他字段
|
|
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
} else {
|
|
// 名称冲突检查
|
|
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
} else if dup {
|
|
common.ApiErrorMsg(c, "模型名称已存在")
|
|
return
|
|
}
|
|
|
|
if err := m.Update(); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
}
|
|
model.RefreshPricing()
|
|
common.ApiSuccess(c, &m)
|
|
}
|
|
|
|
// DeleteModelMeta 删除模型
|
|
func DeleteModelMeta(c *gin.Context) {
|
|
idStr := c.Param("id")
|
|
id, err := strconv.Atoi(idStr)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
model.RefreshPricing()
|
|
common.ApiSuccess(c, nil)
|
|
}
|
|
|
|
// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
|
|
func enrichModels(models []*model.Model) {
|
|
if len(models) == 0 {
|
|
return
|
|
}
|
|
|
|
// 1) 拆分精确与规则匹配
|
|
exactNames := make([]string, 0)
|
|
exactIdx := make(map[string][]int) // modelName -> indices in models
|
|
ruleIndices := make([]int, 0)
|
|
for i, m := range models {
|
|
if m == nil {
|
|
continue
|
|
}
|
|
if m.NameRule == model.NameRuleExact {
|
|
exactNames = append(exactNames, m.ModelName)
|
|
exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
|
|
} else {
|
|
ruleIndices = append(ruleIndices, i)
|
|
}
|
|
}
|
|
|
|
// 2) 批量查询精确模型的绑定渠道
|
|
channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
|
|
|
|
// 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
|
|
for name, indices := range exactIdx {
|
|
chs := channelsByModel[name]
|
|
for _, idx := range indices {
|
|
mm := models[idx]
|
|
if mm.Endpoints == "" {
|
|
eps := model.GetModelSupportEndpointTypes(mm.ModelName)
|
|
if b, err := json.Marshal(eps); err == nil {
|
|
mm.Endpoints = string(b)
|
|
}
|
|
}
|
|
mm.BoundChannels = chs
|
|
mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
|
|
mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
|
|
}
|
|
}
|
|
|
|
if len(ruleIndices) == 0 {
|
|
return
|
|
}
|
|
|
|
// 4) 一次性读取定价缓存,内存匹配所有规则模型
|
|
pricings := model.GetPricing()
|
|
|
|
// 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
|
|
matchedNamesByIdx := make(map[int][]string)
|
|
endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
|
|
groupSetByIdx := make(map[int]map[string]struct{})
|
|
quotaSetByIdx := make(map[int]map[int]struct{})
|
|
|
|
for _, p := range pricings {
|
|
for _, idx := range ruleIndices {
|
|
mm := models[idx]
|
|
var matched bool
|
|
switch mm.NameRule {
|
|
case model.NameRulePrefix:
|
|
matched = strings.HasPrefix(p.ModelName, mm.ModelName)
|
|
case model.NameRuleSuffix:
|
|
matched = strings.HasSuffix(p.ModelName, mm.ModelName)
|
|
case model.NameRuleContains:
|
|
matched = strings.Contains(p.ModelName, mm.ModelName)
|
|
}
|
|
if !matched {
|
|
continue
|
|
}
|
|
matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
|
|
|
|
es := endpointSetByIdx[idx]
|
|
if es == nil {
|
|
es = make(map[constant.EndpointType]struct{})
|
|
endpointSetByIdx[idx] = es
|
|
}
|
|
for _, et := range p.SupportedEndpointTypes {
|
|
es[et] = struct{}{}
|
|
}
|
|
|
|
gs := groupSetByIdx[idx]
|
|
if gs == nil {
|
|
gs = make(map[string]struct{})
|
|
groupSetByIdx[idx] = gs
|
|
}
|
|
for _, g := range p.EnableGroup {
|
|
gs[g] = struct{}{}
|
|
}
|
|
|
|
qs := quotaSetByIdx[idx]
|
|
if qs == nil {
|
|
qs = make(map[int]struct{})
|
|
quotaSetByIdx[idx] = qs
|
|
}
|
|
qs[p.QuotaType] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// 5) 汇总所有匹配到的模型名称,批量查询一次渠道
|
|
allMatchedSet := make(map[string]struct{})
|
|
for _, names := range matchedNamesByIdx {
|
|
for _, n := range names {
|
|
allMatchedSet[n] = struct{}{}
|
|
}
|
|
}
|
|
allMatched := make([]string, 0, len(allMatchedSet))
|
|
for n := range allMatchedSet {
|
|
allMatched = append(allMatched, n)
|
|
}
|
|
matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
|
|
|
|
// 6) 回填每个规则模型的并集信息
|
|
for _, idx := range ruleIndices {
|
|
mm := models[idx]
|
|
|
|
// 端点并集 -> 序列化
|
|
if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
|
|
eps := make([]constant.EndpointType, 0, len(es))
|
|
for et := range es {
|
|
eps = append(eps, et)
|
|
}
|
|
if b, err := json.Marshal(eps); err == nil {
|
|
mm.Endpoints = string(b)
|
|
}
|
|
}
|
|
|
|
// 分组并集
|
|
if gs, ok := groupSetByIdx[idx]; ok {
|
|
groups := make([]string, 0, len(gs))
|
|
for g := range gs {
|
|
groups = append(groups, g)
|
|
}
|
|
mm.EnableGroups = groups
|
|
}
|
|
|
|
// 配额类型集合(保持去重并排序)
|
|
if qs, ok := quotaSetByIdx[idx]; ok {
|
|
arr := make([]int, 0, len(qs))
|
|
for k := range qs {
|
|
arr = append(arr, k)
|
|
}
|
|
sort.Ints(arr)
|
|
mm.QuotaTypes = arr
|
|
}
|
|
|
|
// 渠道并集
|
|
names := matchedNamesByIdx[idx]
|
|
channelSet := make(map[string]model.BoundChannel)
|
|
for _, n := range names {
|
|
for _, ch := range matchedChannelsByModel[n] {
|
|
key := ch.Name + "_" + strconv.Itoa(ch.Type)
|
|
channelSet[key] = ch
|
|
}
|
|
}
|
|
if len(channelSet) > 0 {
|
|
chs := make([]model.BoundChannel, 0, len(channelSet))
|
|
for _, ch := range channelSet {
|
|
chs = append(chs, ch)
|
|
}
|
|
mm.BoundChannels = chs
|
|
}
|
|
|
|
// 匹配信息
|
|
mm.MatchedModels = names
|
|
mm.MatchedCount = len(names)
|
|
}
|
|
}
|