Backend: - Model: Add an `icon` field to `model.Model` (GORM: varchar(128)); auto-migrated via GORM. - Pricing API: Extend `model.Pricing` with `icon` and populate it from the model metadata in `GetPricing()`. Frontend: - EditModelModal: Add an `icon` input (with a @lobehub/icons helper link); wire it into the init/load/submit flows. - ModelHeader / PricingCardView: Prefer rendering `model.icon`; fall back to `vendor_icon`, and finally to an initials avatar. - Models table: Add a leading “Icon” column that renders `model.icon` or the vendor icon via `getLobeHubIcon`. Notes: - Backward-compatible: existing rows without an `icon` value remain unaffected. - No manual SQL is needed; the column is added by AutoMigrate. Affected files: - model/model_meta.go - model/pricing.go - web/src/components/table/models/modals/EditModelModal.jsx - web/src/components/table/model-pricing/modal/components/ModelHeader.jsx - web/src/components/table/model-pricing/view/card/PricingCardView.jsx - web/src/components/table/models/ModelsColumnDefs.js
310 lines
10 KiB
Go
310 lines
10 KiB
Go
package model
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/setting/ratio_setting"
|
||
"one-api/types"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
type Pricing struct {
|
||
ModelName string `json:"model_name"`
|
||
Description string `json:"description,omitempty"`
|
||
Icon string `json:"icon,omitempty"`
|
||
Tags string `json:"tags,omitempty"`
|
||
VendorID int `json:"vendor_id,omitempty"`
|
||
QuotaType int `json:"quota_type"`
|
||
ModelRatio float64 `json:"model_ratio"`
|
||
ModelPrice float64 `json:"model_price"`
|
||
OwnerBy string `json:"owner_by"`
|
||
CompletionRatio float64 `json:"completion_ratio"`
|
||
EnableGroup []string `json:"enable_groups"`
|
||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||
}
|
||
|
||
// PricingVendor is the vendor summary served alongside the pricing list.
// updatePricing builds one entry per row of the vendors table.
type PricingVendor struct {
	ID          int    `json:"id"` // copied from Vendor.Id
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
}
|
||
|
||
var (
|
||
pricingMap []Pricing
|
||
vendorsList []PricingVendor
|
||
supportedEndpointMap map[string]common.EndpointInfo
|
||
lastGetPricingTime time.Time
|
||
updatePricingLock sync.Mutex
|
||
|
||
// 缓存映射:模型名 -> 启用分组 / 计费类型
|
||
modelEnableGroups = make(map[string][]string)
|
||
modelQuotaTypeMap = make(map[string]int)
|
||
modelEnableGroupsLock = sync.RWMutex{}
|
||
)
|
||
|
||
var (
|
||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||
modelSupportEndpointsLock = sync.RWMutex{}
|
||
)
|
||
|
||
func GetPricing() []Pricing {
|
||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||
updatePricingLock.Lock()
|
||
defer updatePricingLock.Unlock()
|
||
// Double check after acquiring the lock
|
||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||
modelSupportEndpointsLock.Lock()
|
||
defer modelSupportEndpointsLock.Unlock()
|
||
updatePricing()
|
||
}
|
||
}
|
||
return pricingMap
|
||
}
|
||
|
||
// GetVendors 返回当前定价接口使用到的供应商信息
|
||
func GetVendors() []PricingVendor {
|
||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||
// 保证先刷新一次
|
||
GetPricing()
|
||
}
|
||
return vendorsList
|
||
}
|
||
|
||
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||
if model == "" {
|
||
return make([]constant.EndpointType, 0)
|
||
}
|
||
modelSupportEndpointsLock.RLock()
|
||
defer modelSupportEndpointsLock.RUnlock()
|
||
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
|
||
return endpoints
|
||
}
|
||
return make([]constant.EndpointType, 0)
|
||
}
|
||
|
||
func updatePricing() {
|
||
//modelRatios := common.GetModelRatios()
|
||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||
if err != nil {
|
||
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||
return
|
||
}
|
||
// 预加载模型元数据与供应商一次,避免循环查询
|
||
var allMeta []Model
|
||
_ = DB.Find(&allMeta).Error
|
||
metaMap := make(map[string]*Model)
|
||
prefixList := make([]*Model, 0)
|
||
suffixList := make([]*Model, 0)
|
||
containsList := make([]*Model, 0)
|
||
for i := range allMeta {
|
||
m := &allMeta[i]
|
||
if m.NameRule == NameRuleExact {
|
||
metaMap[m.ModelName] = m
|
||
} else {
|
||
switch m.NameRule {
|
||
case NameRulePrefix:
|
||
prefixList = append(prefixList, m)
|
||
case NameRuleSuffix:
|
||
suffixList = append(suffixList, m)
|
||
case NameRuleContains:
|
||
containsList = append(containsList, m)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 将非精确规则模型匹配到 metaMap
|
||
for _, m := range prefixList {
|
||
for _, pricingModel := range enableAbilities {
|
||
if strings.HasPrefix(pricingModel.Model, m.ModelName) {
|
||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||
metaMap[pricingModel.Model] = m
|
||
}
|
||
}
|
||
}
|
||
}
|
||
for _, m := range suffixList {
|
||
for _, pricingModel := range enableAbilities {
|
||
if strings.HasSuffix(pricingModel.Model, m.ModelName) {
|
||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||
metaMap[pricingModel.Model] = m
|
||
}
|
||
}
|
||
}
|
||
}
|
||
for _, m := range containsList {
|
||
for _, pricingModel := range enableAbilities {
|
||
if strings.Contains(pricingModel.Model, m.ModelName) {
|
||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||
metaMap[pricingModel.Model] = m
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 预加载供应商
|
||
var vendors []Vendor
|
||
_ = DB.Find(&vendors).Error
|
||
vendorMap := make(map[int]*Vendor)
|
||
for i := range vendors {
|
||
vendorMap[vendors[i].Id] = &vendors[i]
|
||
}
|
||
|
||
// 构建对前端友好的供应商列表
|
||
vendorsList = make([]PricingVendor, 0, len(vendors))
|
||
for _, v := range vendors {
|
||
vendorsList = append(vendorsList, PricingVendor{
|
||
ID: v.Id,
|
||
Name: v.Name,
|
||
Description: v.Description,
|
||
Icon: v.Icon,
|
||
})
|
||
}
|
||
|
||
modelGroupsMap := make(map[string]*types.Set[string])
|
||
|
||
for _, ability := range enableAbilities {
|
||
groups, ok := modelGroupsMap[ability.Model]
|
||
if !ok {
|
||
groups = types.NewSet[string]()
|
||
modelGroupsMap[ability.Model] = groups
|
||
}
|
||
groups.Add(ability.Group)
|
||
}
|
||
|
||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||
modelSupportEndpointsStr := make(map[string][]string)
|
||
|
||
// 先根据已有能力填充原生端点
|
||
for _, ability := range enableAbilities {
|
||
endpoints := modelSupportEndpointsStr[ability.Model]
|
||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||
for _, channelType := range channelTypes {
|
||
if !common.StringsContains(endpoints, string(channelType)) {
|
||
endpoints = append(endpoints, string(channelType))
|
||
}
|
||
}
|
||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||
}
|
||
|
||
// 再补充模型自定义端点
|
||
for modelName, meta := range metaMap {
|
||
if strings.TrimSpace(meta.Endpoints) == "" {
|
||
continue
|
||
}
|
||
var raw map[string]interface{}
|
||
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||
endpoints := modelSupportEndpointsStr[modelName]
|
||
for k := range raw {
|
||
if !common.StringsContains(endpoints, k) {
|
||
endpoints = append(endpoints, k)
|
||
}
|
||
}
|
||
modelSupportEndpointsStr[modelName] = endpoints
|
||
}
|
||
}
|
||
|
||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||
for model, endpoints := range modelSupportEndpointsStr {
|
||
supportedEndpoints := make([]constant.EndpointType, 0)
|
||
for _, endpointStr := range endpoints {
|
||
endpointType := constant.EndpointType(endpointStr)
|
||
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||
}
|
||
modelSupportEndpointTypes[model] = supportedEndpoints
|
||
}
|
||
|
||
// 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
|
||
supportedEndpointMap = make(map[string]common.EndpointInfo)
|
||
// 1. 默认端点
|
||
for _, endpoints := range modelSupportEndpointTypes {
|
||
for _, et := range endpoints {
|
||
if info, ok := common.GetDefaultEndpointInfo(et); ok {
|
||
if _, exists := supportedEndpointMap[string(et)]; !exists {
|
||
supportedEndpointMap[string(et)] = info
|
||
}
|
||
}
|
||
}
|
||
}
|
||
// 2. 自定义端点(models 表)覆盖默认
|
||
for _, meta := range metaMap {
|
||
if strings.TrimSpace(meta.Endpoints) == "" {
|
||
continue
|
||
}
|
||
var raw map[string]interface{}
|
||
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||
for k, v := range raw {
|
||
switch val := v.(type) {
|
||
case string:
|
||
supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
|
||
case map[string]interface{}:
|
||
ep := common.EndpointInfo{Method: "POST"}
|
||
if p, ok := val["path"].(string); ok {
|
||
ep.Path = p
|
||
}
|
||
if m, ok := val["method"].(string); ok {
|
||
ep.Method = strings.ToUpper(m)
|
||
}
|
||
supportedEndpointMap[k] = ep
|
||
default:
|
||
// ignore unsupported types
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
pricingMap = make([]Pricing, 0)
|
||
for model, groups := range modelGroupsMap {
|
||
pricing := Pricing{
|
||
ModelName: model,
|
||
EnableGroup: groups.Items(),
|
||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||
}
|
||
|
||
// 补充模型元数据(描述、标签、供应商、状态)
|
||
if meta, ok := metaMap[model]; ok {
|
||
// 若模型被禁用(status!=1),则直接跳过,不返回给前端
|
||
if meta.Status != 1 {
|
||
continue
|
||
}
|
||
pricing.Description = meta.Description
|
||
pricing.Icon = meta.Icon
|
||
pricing.Tags = meta.Tags
|
||
pricing.VendorID = meta.VendorID
|
||
}
|
||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||
if findPrice {
|
||
pricing.ModelPrice = modelPrice
|
||
pricing.QuotaType = 1
|
||
} else {
|
||
modelRatio, _, _ := ratio_setting.GetModelRatio(model)
|
||
pricing.ModelRatio = modelRatio
|
||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||
pricing.QuotaType = 0
|
||
}
|
||
pricingMap = append(pricingMap, pricing)
|
||
}
|
||
|
||
// 刷新缓存映射,供高并发快速查询
|
||
modelEnableGroupsLock.Lock()
|
||
modelEnableGroups = make(map[string][]string)
|
||
modelQuotaTypeMap = make(map[string]int)
|
||
for _, p := range pricingMap {
|
||
modelEnableGroups[p.ModelName] = p.EnableGroup
|
||
modelQuotaTypeMap[p.ModelName] = p.QuotaType
|
||
}
|
||
modelEnableGroupsLock.Unlock()
|
||
|
||
lastGetPricingTime = time.Now()
|
||
}
|
||
|
||
// GetSupportedEndpointMap 返回全局端点到路径的映射
|
||
func GetSupportedEndpointMap() map[string]common.EndpointInfo {
|
||
return supportedEndpointMap
|
||
}
|