This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include:
- Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices.
- Updated resource-initialization error handling to use the new logging structure, improving maintainability and readability.
- Made minor adjustments to improve code clarity and organization across various modules.

This change aims to streamline logging and improve the overall architecture of the codebase.
310 lines
9.2 KiB
Go
310 lines
9.2 KiB
Go
package model
|
||
|
||
import (
	"encoding/json"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"one-api/common"
	"one-api/constant"
	"one-api/setting/ratio_setting"
	"one-api/types"
)
|
||
|
||
type Pricing struct {
|
||
ModelName string `json:"model_name"`
|
||
Description string `json:"description,omitempty"`
|
||
Icon string `json:"icon,omitempty"`
|
||
Tags string `json:"tags,omitempty"`
|
||
VendorID int `json:"vendor_id,omitempty"`
|
||
QuotaType int `json:"quota_type"`
|
||
ModelRatio float64 `json:"model_ratio"`
|
||
ModelPrice float64 `json:"model_price"`
|
||
OwnerBy string `json:"owner_by"`
|
||
CompletionRatio float64 `json:"completion_ratio"`
|
||
EnableGroup []string `json:"enable_groups"`
|
||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||
}
|
||
|
||
// PricingVendor is the frontend-facing view of a vendor row, returned by
// GetVendors alongside the pricing list.
type PricingVendor struct {
	ID          int    `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
}
|
||
|
||
var (
|
||
pricingMap []Pricing
|
||
vendorsList []PricingVendor
|
||
supportedEndpointMap map[string]common.EndpointInfo
|
||
lastGetPricingTime time.Time
|
||
updatePricingLock sync.Mutex
|
||
|
||
// 缓存映射:模型名 -> 启用分组 / 计费类型
|
||
modelEnableGroups = make(map[string][]string)
|
||
modelQuotaTypeMap = make(map[string]int)
|
||
modelEnableGroupsLock = sync.RWMutex{}
|
||
)
|
||
|
||
var (
|
||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||
modelSupportEndpointsLock = sync.RWMutex{}
|
||
)
|
||
|
||
func GetPricing() []Pricing {
|
||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||
updatePricingLock.Lock()
|
||
defer updatePricingLock.Unlock()
|
||
// Double check after acquiring the lock
|
||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||
modelSupportEndpointsLock.Lock()
|
||
defer modelSupportEndpointsLock.Unlock()
|
||
updatePricing()
|
||
}
|
||
}
|
||
return pricingMap
|
||
}
|
||
|
||
// GetVendors 返回当前定价接口使用到的供应商信息
|
||
func GetVendors() []PricingVendor {
|
||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||
// 保证先刷新一次
|
||
GetPricing()
|
||
}
|
||
return vendorsList
|
||
}
|
||
|
||
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||
if model == "" {
|
||
return make([]constant.EndpointType, 0)
|
||
}
|
||
modelSupportEndpointsLock.RLock()
|
||
defer modelSupportEndpointsLock.RUnlock()
|
||
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
|
||
return endpoints
|
||
}
|
||
return make([]constant.EndpointType, 0)
|
||
}
|
||
|
||
func updatePricing() {
|
||
//modelRatios := common.GetModelRatios()
|
||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||
if err != nil {
|
||
common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||
return
|
||
}
|
||
// 预加载模型元数据与供应商一次,避免循环查询
|
||
var allMeta []Model
|
||
_ = DB.Find(&allMeta).Error
|
||
metaMap := make(map[string]*Model)
|
||
prefixList := make([]*Model, 0)
|
||
suffixList := make([]*Model, 0)
|
||
containsList := make([]*Model, 0)
|
||
for i := range allMeta {
|
||
m := &allMeta[i]
|
||
if m.NameRule == NameRuleExact {
|
||
metaMap[m.ModelName] = m
|
||
} else {
|
||
switch m.NameRule {
|
||
case NameRulePrefix:
|
||
prefixList = append(prefixList, m)
|
||
case NameRuleSuffix:
|
||
suffixList = append(suffixList, m)
|
||
case NameRuleContains:
|
||
containsList = append(containsList, m)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 将非精确规则模型匹配到 metaMap
|
||
for _, m := range prefixList {
|
||
for _, pricingModel := range enableAbilities {
|
||
if strings.HasPrefix(pricingModel.Model, m.ModelName) {
|
||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||
metaMap[pricingModel.Model] = m
|
||
}
|
||
}
|
||
}
|
||
}
|
||
for _, m := range suffixList {
|
||
for _, pricingModel := range enableAbilities {
|
||
if strings.HasSuffix(pricingModel.Model, m.ModelName) {
|
||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||
metaMap[pricingModel.Model] = m
|
||
}
|
||
}
|
||
}
|
||
}
|
||
for _, m := range containsList {
|
||
for _, pricingModel := range enableAbilities {
|
||
if strings.Contains(pricingModel.Model, m.ModelName) {
|
||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||
metaMap[pricingModel.Model] = m
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 预加载供应商
|
||
var vendors []Vendor
|
||
_ = DB.Find(&vendors).Error
|
||
vendorMap := make(map[int]*Vendor)
|
||
for i := range vendors {
|
||
vendorMap[vendors[i].Id] = &vendors[i]
|
||
}
|
||
|
||
// 构建对前端友好的供应商列表
|
||
vendorsList = make([]PricingVendor, 0, len(vendors))
|
||
for _, v := range vendors {
|
||
vendorsList = append(vendorsList, PricingVendor{
|
||
ID: v.Id,
|
||
Name: v.Name,
|
||
Description: v.Description,
|
||
Icon: v.Icon,
|
||
})
|
||
}
|
||
|
||
modelGroupsMap := make(map[string]*types.Set[string])
|
||
|
||
for _, ability := range enableAbilities {
|
||
groups, ok := modelGroupsMap[ability.Model]
|
||
if !ok {
|
||
groups = types.NewSet[string]()
|
||
modelGroupsMap[ability.Model] = groups
|
||
}
|
||
groups.Add(ability.Group)
|
||
}
|
||
|
||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||
modelSupportEndpointsStr := make(map[string][]string)
|
||
|
||
// 先根据已有能力填充原生端点
|
||
for _, ability := range enableAbilities {
|
||
endpoints := modelSupportEndpointsStr[ability.Model]
|
||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||
for _, channelType := range channelTypes {
|
||
if !common.StringsContains(endpoints, string(channelType)) {
|
||
endpoints = append(endpoints, string(channelType))
|
||
}
|
||
}
|
||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||
}
|
||
|
||
// 再补充模型自定义端点
|
||
for modelName, meta := range metaMap {
|
||
if strings.TrimSpace(meta.Endpoints) == "" {
|
||
continue
|
||
}
|
||
var raw map[string]interface{}
|
||
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||
endpoints := modelSupportEndpointsStr[modelName]
|
||
for k := range raw {
|
||
if !common.StringsContains(endpoints, k) {
|
||
endpoints = append(endpoints, k)
|
||
}
|
||
}
|
||
modelSupportEndpointsStr[modelName] = endpoints
|
||
}
|
||
}
|
||
|
||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||
for model, endpoints := range modelSupportEndpointsStr {
|
||
supportedEndpoints := make([]constant.EndpointType, 0)
|
||
for _, endpointStr := range endpoints {
|
||
endpointType := constant.EndpointType(endpointStr)
|
||
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||
}
|
||
modelSupportEndpointTypes[model] = supportedEndpoints
|
||
}
|
||
|
||
// 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
|
||
supportedEndpointMap = make(map[string]common.EndpointInfo)
|
||
// 1. 默认端点
|
||
for _, endpoints := range modelSupportEndpointTypes {
|
||
for _, et := range endpoints {
|
||
if info, ok := common.GetDefaultEndpointInfo(et); ok {
|
||
if _, exists := supportedEndpointMap[string(et)]; !exists {
|
||
supportedEndpointMap[string(et)] = info
|
||
}
|
||
}
|
||
}
|
||
}
|
||
// 2. 自定义端点(models 表)覆盖默认
|
||
for _, meta := range metaMap {
|
||
if strings.TrimSpace(meta.Endpoints) == "" {
|
||
continue
|
||
}
|
||
var raw map[string]interface{}
|
||
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||
for k, v := range raw {
|
||
switch val := v.(type) {
|
||
case string:
|
||
supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
|
||
case map[string]interface{}:
|
||
ep := common.EndpointInfo{Method: "POST"}
|
||
if p, ok := val["path"].(string); ok {
|
||
ep.Path = p
|
||
}
|
||
if m, ok := val["method"].(string); ok {
|
||
ep.Method = strings.ToUpper(m)
|
||
}
|
||
supportedEndpointMap[k] = ep
|
||
default:
|
||
// ignore unsupported types
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
pricingMap = make([]Pricing, 0)
|
||
for model, groups := range modelGroupsMap {
|
||
pricing := Pricing{
|
||
ModelName: model,
|
||
EnableGroup: groups.Items(),
|
||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||
}
|
||
|
||
// 补充模型元数据(描述、标签、供应商、状态)
|
||
if meta, ok := metaMap[model]; ok {
|
||
// 若模型被禁用(status!=1),则直接跳过,不返回给前端
|
||
if meta.Status != 1 {
|
||
continue
|
||
}
|
||
pricing.Description = meta.Description
|
||
pricing.Icon = meta.Icon
|
||
pricing.Tags = meta.Tags
|
||
pricing.VendorID = meta.VendorID
|
||
}
|
||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||
if findPrice {
|
||
pricing.ModelPrice = modelPrice
|
||
pricing.QuotaType = 1
|
||
} else {
|
||
modelRatio, _, _ := ratio_setting.GetModelRatio(model)
|
||
pricing.ModelRatio = modelRatio
|
||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||
pricing.QuotaType = 0
|
||
}
|
||
pricingMap = append(pricingMap, pricing)
|
||
}
|
||
|
||
// 刷新缓存映射,供高并发快速查询
|
||
modelEnableGroupsLock.Lock()
|
||
modelEnableGroups = make(map[string][]string)
|
||
modelQuotaTypeMap = make(map[string]int)
|
||
for _, p := range pricingMap {
|
||
modelEnableGroups[p.ModelName] = p.EnableGroup
|
||
modelQuotaTypeMap[p.ModelName] = p.QuotaType
|
||
}
|
||
modelEnableGroupsLock.Unlock()
|
||
|
||
lastGetPricingTime = time.Now()
|
||
}
|
||
|
||
// GetSupportedEndpointMap 返回全局端点到路径的映射
|
||
func GetSupportedEndpointMap() map[string]common.EndpointInfo {
|
||
return supportedEndpointMap
|
||
}
|