Summary
-------
1. Pricing generation
• `model/pricing.go`: skip any model whose `status != 1` when building
`pricingMap`, ensuring disabled models are never returned to the
front-end.
2. Cache refresh placement
• `controller/model_meta.go`
– Removed `model.RefreshPricing()` from pure read handlers
(`GetAllModelsMeta`, `SearchModelsMeta`).
– Kept refresh only in mutating handlers
(`Create`, `Update`, `Delete`), guaranteeing data is updated
immediately after an admin change while avoiding redundant work
on every read.
Result
------
The front-end no longer receives information about disabled models, and
the pricing cache is refreshed exactly when model data is modified,
improving efficiency and consistency.
247 lines
7.6 KiB
Go
247 lines
7.6 KiB
Go
package model
|
||
|
||
import (
|
||
"fmt"
|
||
"strings"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/setting/ratio_setting"
|
||
"one-api/types"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
type Pricing struct {
|
||
ModelName string `json:"model_name"`
|
||
Description string `json:"description,omitempty"`
|
||
Tags string `json:"tags,omitempty"`
|
||
VendorID int `json:"vendor_id,omitempty"`
|
||
QuotaType int `json:"quota_type"`
|
||
ModelRatio float64 `json:"model_ratio"`
|
||
ModelPrice float64 `json:"model_price"`
|
||
OwnerBy string `json:"owner_by"`
|
||
CompletionRatio float64 `json:"completion_ratio"`
|
||
EnableGroup []string `json:"enable_groups"`
|
||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||
}
|
||
|
||
// PricingVendor is the vendor record returned by GetVendors. Description and
// Icon are left out of the JSON output when they are empty.
type PricingVendor struct {
	// ID is the vendor's identifier.
	ID int `json:"id"`
	// Name is the vendor's name.
	Name string `json:"name"`
	// Description is the vendor's description text.
	Description string `json:"description,omitempty"`
	// Icon is the vendor's icon value.
	Icon string `json:"icon,omitempty"`
}
|
||
|
||
var (
|
||
pricingMap []Pricing
|
||
vendorsList []PricingVendor
|
||
lastGetPricingTime time.Time
|
||
updatePricingLock sync.Mutex
|
||
|
||
// 缓存映射:模型名 -> 启用分组 / 计费类型
|
||
modelEnableGroups = make(map[string][]string)
|
||
modelQuotaTypeMap = make(map[string]int)
|
||
modelEnableGroupsLock = sync.RWMutex{}
|
||
)
|
||
|
||
var (
|
||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||
modelSupportEndpointsLock = sync.RWMutex{}
|
||
)
|
||
|
||
func GetPricing() []Pricing {
|
||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||
updatePricingLock.Lock()
|
||
defer updatePricingLock.Unlock()
|
||
// Double check after acquiring the lock
|
||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||
modelSupportEndpointsLock.Lock()
|
||
defer modelSupportEndpointsLock.Unlock()
|
||
updatePricing()
|
||
}
|
||
}
|
||
return pricingMap
|
||
}
|
||
|
||
// GetVendors 返回当前定价接口使用到的供应商信息
|
||
func GetVendors() []PricingVendor {
|
||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||
// 保证先刷新一次
|
||
GetPricing()
|
||
}
|
||
return vendorsList
|
||
}
|
||
|
||
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||
if model == "" {
|
||
return make([]constant.EndpointType, 0)
|
||
}
|
||
modelSupportEndpointsLock.RLock()
|
||
defer modelSupportEndpointsLock.RUnlock()
|
||
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
|
||
return endpoints
|
||
}
|
||
return make([]constant.EndpointType, 0)
|
||
}
|
||
|
||
func updatePricing() {
|
||
//modelRatios := common.GetModelRatios()
|
||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||
if err != nil {
|
||
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||
return
|
||
}
|
||
// 预加载模型元数据与供应商一次,避免循环查询
|
||
var allMeta []Model
|
||
_ = DB.Find(&allMeta).Error
|
||
metaMap := make(map[string]*Model)
|
||
prefixList := make([]*Model, 0)
|
||
suffixList := make([]*Model, 0)
|
||
containsList := make([]*Model, 0)
|
||
for i := range allMeta {
|
||
m := &allMeta[i]
|
||
if m.NameRule == NameRuleExact {
|
||
metaMap[m.ModelName] = m
|
||
} else {
|
||
switch m.NameRule {
|
||
case NameRulePrefix:
|
||
prefixList = append(prefixList, m)
|
||
case NameRuleSuffix:
|
||
suffixList = append(suffixList, m)
|
||
case NameRuleContains:
|
||
containsList = append(containsList, m)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 将非精确规则模型匹配到 metaMap
|
||
for _, m := range prefixList {
|
||
for _, pricingModel := range enableAbilities {
|
||
if strings.HasPrefix(pricingModel.Model, m.ModelName) {
|
||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||
metaMap[pricingModel.Model] = m
|
||
}
|
||
}
|
||
}
|
||
}
|
||
for _, m := range suffixList {
|
||
for _, pricingModel := range enableAbilities {
|
||
if strings.HasSuffix(pricingModel.Model, m.ModelName) {
|
||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||
metaMap[pricingModel.Model] = m
|
||
}
|
||
}
|
||
}
|
||
}
|
||
for _, m := range containsList {
|
||
for _, pricingModel := range enableAbilities {
|
||
if strings.Contains(pricingModel.Model, m.ModelName) {
|
||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||
metaMap[pricingModel.Model] = m
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 预加载供应商
|
||
var vendors []Vendor
|
||
_ = DB.Find(&vendors).Error
|
||
vendorMap := make(map[int]*Vendor)
|
||
for i := range vendors {
|
||
vendorMap[vendors[i].Id] = &vendors[i]
|
||
}
|
||
|
||
// 构建对前端友好的供应商列表
|
||
vendorsList = make([]PricingVendor, 0, len(vendors))
|
||
for _, v := range vendors {
|
||
vendorsList = append(vendorsList, PricingVendor{
|
||
ID: v.Id,
|
||
Name: v.Name,
|
||
Description: v.Description,
|
||
Icon: v.Icon,
|
||
})
|
||
}
|
||
|
||
modelGroupsMap := make(map[string]*types.Set[string])
|
||
|
||
for _, ability := range enableAbilities {
|
||
groups, ok := modelGroupsMap[ability.Model]
|
||
if !ok {
|
||
groups = types.NewSet[string]()
|
||
modelGroupsMap[ability.Model] = groups
|
||
}
|
||
groups.Add(ability.Group)
|
||
}
|
||
|
||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||
modelSupportEndpointsStr := make(map[string][]string)
|
||
|
||
for _, ability := range enableAbilities {
|
||
endpoints, ok := modelSupportEndpointsStr[ability.Model]
|
||
if !ok {
|
||
endpoints = make([]string, 0)
|
||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||
}
|
||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||
for _, channelType := range channelTypes {
|
||
if !common.StringsContains(endpoints, string(channelType)) {
|
||
endpoints = append(endpoints, string(channelType))
|
||
}
|
||
}
|
||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||
}
|
||
|
||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||
for model, endpoints := range modelSupportEndpointsStr {
|
||
supportedEndpoints := make([]constant.EndpointType, 0)
|
||
for _, endpointStr := range endpoints {
|
||
endpointType := constant.EndpointType(endpointStr)
|
||
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||
}
|
||
modelSupportEndpointTypes[model] = supportedEndpoints
|
||
}
|
||
|
||
pricingMap = make([]Pricing, 0)
|
||
for model, groups := range modelGroupsMap {
|
||
pricing := Pricing{
|
||
ModelName: model,
|
||
EnableGroup: groups.Items(),
|
||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||
}
|
||
|
||
// 补充模型元数据(描述、标签、供应商、状态)
|
||
if meta, ok := metaMap[model]; ok {
|
||
// 若模型被禁用(status!=1),则直接跳过,不返回给前端
|
||
if meta.Status != 1 {
|
||
continue
|
||
}
|
||
pricing.Description = meta.Description
|
||
pricing.Tags = meta.Tags
|
||
pricing.VendorID = meta.VendorID
|
||
}
|
||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||
if findPrice {
|
||
pricing.ModelPrice = modelPrice
|
||
pricing.QuotaType = 1
|
||
} else {
|
||
modelRatio, _, _ := ratio_setting.GetModelRatio(model)
|
||
pricing.ModelRatio = modelRatio
|
||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||
pricing.QuotaType = 0
|
||
}
|
||
pricingMap = append(pricingMap, pricing)
|
||
}
|
||
|
||
// 刷新缓存映射,供高并发快速查询
|
||
modelEnableGroupsLock.Lock()
|
||
modelEnableGroups = make(map[string][]string)
|
||
modelQuotaTypeMap = make(map[string]int)
|
||
for _, p := range pricingMap {
|
||
modelEnableGroups[p.ModelName] = p.EnableGroup
|
||
modelQuotaTypeMap[p.ModelName] = p.QuotaType
|
||
}
|
||
modelEnableGroupsLock.Unlock()
|
||
|
||
lastGetPricingTime = time.Now()
|
||
}
|