This commit refactors logging across the application, replacing direct logger calls with a centralized approach based on the `common` package. Key changes:
- Replaced `logger.SysLog` and `logger.FatalLog` calls with `common.SysLog` and `common.FatalLog` for consistent logging.
- Updated resource-initialization error handling to use the new logging structure, improving maintainability and readability.
- Made minor adjustments to improve code clarity and organization across several modules.

This change streamlines logging and improves the overall architecture of the codebase.
285 lines
7.9 KiB
Go
package model
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"math/rand"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/setting"
|
||
"one-api/setting/ratio_setting"
|
||
"sort"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
var group2model2channels map[string]map[string][]int // enabled channel
|
||
var channelsIDM map[int]*Channel // all channels include disabled
|
||
var channelSyncLock sync.RWMutex
|
||
|
||
func InitChannelCache() {
|
||
if !common.MemoryCacheEnabled {
|
||
return
|
||
}
|
||
newChannelId2channel := make(map[int]*Channel)
|
||
var channels []*Channel
|
||
DB.Find(&channels)
|
||
for _, channel := range channels {
|
||
newChannelId2channel[channel.Id] = channel
|
||
}
|
||
var abilities []*Ability
|
||
DB.Find(&abilities)
|
||
groups := make(map[string]bool)
|
||
for _, ability := range abilities {
|
||
groups[ability.Group] = true
|
||
}
|
||
newGroup2model2channels := make(map[string]map[string][]int)
|
||
for group := range groups {
|
||
newGroup2model2channels[group] = make(map[string][]int)
|
||
}
|
||
for _, channel := range channels {
|
||
if channel.Status != common.ChannelStatusEnabled {
|
||
continue // skip disabled channels
|
||
}
|
||
groups := strings.Split(channel.Group, ",")
|
||
for _, group := range groups {
|
||
models := strings.Split(channel.Models, ",")
|
||
for _, model := range models {
|
||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
||
newGroup2model2channels[group][model] = make([]int, 0)
|
||
}
|
||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
|
||
}
|
||
}
|
||
}
|
||
|
||
// sort by priority
|
||
for group, model2channels := range newGroup2model2channels {
|
||
for model, channels := range model2channels {
|
||
sort.Slice(channels, func(i, j int) bool {
|
||
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
|
||
})
|
||
newGroup2model2channels[group][model] = channels
|
||
}
|
||
}
|
||
|
||
channelSyncLock.Lock()
|
||
group2model2channels = newGroup2model2channels
|
||
//channelsIDM = newChannelId2channel
|
||
for i, channel := range newChannelId2channel {
|
||
if channel.ChannelInfo.IsMultiKey {
|
||
channel.Keys = channel.GetKeys()
|
||
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||
if oldChannel, ok := channelsIDM[i]; ok {
|
||
// 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
|
||
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
channelsIDM = newChannelId2channel
|
||
channelSyncLock.Unlock()
|
||
common.SysLog("channels synced from database")
|
||
}
|
||
|
||
func SyncChannelCache(frequency int) {
|
||
for {
|
||
time.Sleep(time.Duration(frequency) * time.Second)
|
||
common.SysLog("syncing channels from database")
|
||
InitChannelCache()
|
||
}
|
||
}
|
||
|
||
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
|
||
var channel *Channel
|
||
var err error
|
||
selectGroup := group
|
||
if group == "auto" {
|
||
if len(setting.AutoGroups) == 0 {
|
||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||
}
|
||
for _, autoGroup := range setting.AutoGroups {
|
||
if common.DebugEnabled {
|
||
println("autoGroup:", autoGroup)
|
||
}
|
||
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
|
||
if channel == nil {
|
||
continue
|
||
} else {
|
||
c.Set("auto_group", autoGroup)
|
||
selectGroup = autoGroup
|
||
if common.DebugEnabled {
|
||
println("selectGroup:", selectGroup)
|
||
}
|
||
break
|
||
}
|
||
}
|
||
} else {
|
||
channel, err = getRandomSatisfiedChannel(group, model, retry)
|
||
if err != nil {
|
||
return nil, group, err
|
||
}
|
||
}
|
||
return channel, selectGroup, nil
|
||
}
|
||
|
||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||
// if memory cache is disabled, get channel directly from database
|
||
if !common.MemoryCacheEnabled {
|
||
return GetRandomSatisfiedChannel(group, model, retry)
|
||
}
|
||
|
||
channelSyncLock.RLock()
|
||
defer channelSyncLock.RUnlock()
|
||
|
||
// First, try to find channels with the exact model name.
|
||
channels := group2model2channels[group][model]
|
||
|
||
// If no channels found, try to find channels with the normalized model name.
|
||
if len(channels) == 0 {
|
||
normalizedModel := ratio_setting.FormatMatchingModelName(model)
|
||
channels = group2model2channels[group][normalizedModel]
|
||
}
|
||
|
||
if len(channels) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
if len(channels) == 1 {
|
||
if channel, ok := channelsIDM[channels[0]]; ok {
|
||
return channel, nil
|
||
}
|
||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
|
||
}
|
||
|
||
uniquePriorities := make(map[int]bool)
|
||
for _, channelId := range channels {
|
||
if channel, ok := channelsIDM[channelId]; ok {
|
||
uniquePriorities[int(channel.GetPriority())] = true
|
||
} else {
|
||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||
}
|
||
}
|
||
var sortedUniquePriorities []int
|
||
for priority := range uniquePriorities {
|
||
sortedUniquePriorities = append(sortedUniquePriorities, priority)
|
||
}
|
||
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
|
||
|
||
if retry >= len(uniquePriorities) {
|
||
retry = len(uniquePriorities) - 1
|
||
}
|
||
targetPriority := int64(sortedUniquePriorities[retry])
|
||
|
||
// get the priority for the given retry number
|
||
var targetChannels []*Channel
|
||
for _, channelId := range channels {
|
||
if channel, ok := channelsIDM[channelId]; ok {
|
||
if channel.GetPriority() == targetPriority {
|
||
targetChannels = append(targetChannels, channel)
|
||
}
|
||
} else {
|
||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||
}
|
||
}
|
||
|
||
// 平滑系数
|
||
smoothingFactor := 10
|
||
// Calculate the total weight of all channels up to endIdx
|
||
totalWeight := 0
|
||
for _, channel := range targetChannels {
|
||
totalWeight += channel.GetWeight() + smoothingFactor
|
||
}
|
||
// Generate a random value in the range [0, totalWeight)
|
||
randomWeight := rand.Intn(totalWeight)
|
||
|
||
// Find a channel based on its weight
|
||
for _, channel := range targetChannels {
|
||
randomWeight -= channel.GetWeight() + smoothingFactor
|
||
if randomWeight < 0 {
|
||
return channel, nil
|
||
}
|
||
}
|
||
// return null if no channel is not found
|
||
return nil, errors.New("channel not found")
|
||
}
|
||
|
||
func CacheGetChannel(id int) (*Channel, error) {
|
||
if !common.MemoryCacheEnabled {
|
||
return GetChannelById(id, true)
|
||
}
|
||
channelSyncLock.RLock()
|
||
defer channelSyncLock.RUnlock()
|
||
|
||
c, ok := channelsIDM[id]
|
||
if !ok {
|
||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||
}
|
||
return c, nil
|
||
}
|
||
|
||
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
||
if !common.MemoryCacheEnabled {
|
||
channel, err := GetChannelById(id, true)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &channel.ChannelInfo, nil
|
||
}
|
||
channelSyncLock.RLock()
|
||
defer channelSyncLock.RUnlock()
|
||
|
||
c, ok := channelsIDM[id]
|
||
if !ok {
|
||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||
}
|
||
return &c.ChannelInfo, nil
|
||
}
|
||
|
||
func CacheUpdateChannelStatus(id int, status int) {
|
||
if !common.MemoryCacheEnabled {
|
||
return
|
||
}
|
||
channelSyncLock.Lock()
|
||
defer channelSyncLock.Unlock()
|
||
if channel, ok := channelsIDM[id]; ok {
|
||
channel.Status = status
|
||
}
|
||
if status != common.ChannelStatusEnabled {
|
||
// delete the channel from group2model2channels
|
||
for group, model2channels := range group2model2channels {
|
||
for model, channels := range model2channels {
|
||
for i, channelId := range channels {
|
||
if channelId == id {
|
||
// remove the channel from the slice
|
||
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func CacheUpdateChannel(channel *Channel) {
|
||
if !common.MemoryCacheEnabled {
|
||
return
|
||
}
|
||
channelSyncLock.Lock()
|
||
defer channelSyncLock.Unlock()
|
||
if channel == nil {
|
||
return
|
||
}
|
||
|
||
println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
|
||
|
||
println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||
channelsIDM[channel.Id] = channel
|
||
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||
}
|