This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic. Key changes: - **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests. - **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels. - **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package. - **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
286 lines
7.9 KiB
Go
286 lines
7.9 KiB
Go
package model
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"math/rand"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/logger"
|
||
"one-api/setting"
|
||
"one-api/setting/ratio_setting"
|
||
"sort"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
var group2model2channels map[string]map[string][]int // enabled channel
|
||
var channelsIDM map[int]*Channel // all channels include disabled
|
||
var channelSyncLock sync.RWMutex
|
||
|
||
func InitChannelCache() {
|
||
if !common.MemoryCacheEnabled {
|
||
return
|
||
}
|
||
newChannelId2channel := make(map[int]*Channel)
|
||
var channels []*Channel
|
||
DB.Find(&channels)
|
||
for _, channel := range channels {
|
||
newChannelId2channel[channel.Id] = channel
|
||
}
|
||
var abilities []*Ability
|
||
DB.Find(&abilities)
|
||
groups := make(map[string]bool)
|
||
for _, ability := range abilities {
|
||
groups[ability.Group] = true
|
||
}
|
||
newGroup2model2channels := make(map[string]map[string][]int)
|
||
for group := range groups {
|
||
newGroup2model2channels[group] = make(map[string][]int)
|
||
}
|
||
for _, channel := range channels {
|
||
if channel.Status != common.ChannelStatusEnabled {
|
||
continue // skip disabled channels
|
||
}
|
||
groups := strings.Split(channel.Group, ",")
|
||
for _, group := range groups {
|
||
models := strings.Split(channel.Models, ",")
|
||
for _, model := range models {
|
||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
||
newGroup2model2channels[group][model] = make([]int, 0)
|
||
}
|
||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
|
||
}
|
||
}
|
||
}
|
||
|
||
// sort by priority
|
||
for group, model2channels := range newGroup2model2channels {
|
||
for model, channels := range model2channels {
|
||
sort.Slice(channels, func(i, j int) bool {
|
||
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
|
||
})
|
||
newGroup2model2channels[group][model] = channels
|
||
}
|
||
}
|
||
|
||
channelSyncLock.Lock()
|
||
group2model2channels = newGroup2model2channels
|
||
//channelsIDM = newChannelId2channel
|
||
for i, channel := range newChannelId2channel {
|
||
if channel.ChannelInfo.IsMultiKey {
|
||
channel.Keys = channel.GetKeys()
|
||
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||
if oldChannel, ok := channelsIDM[i]; ok {
|
||
// 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
|
||
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
|
||
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
channelsIDM = newChannelId2channel
|
||
channelSyncLock.Unlock()
|
||
logger.SysLog("channels synced from database")
|
||
}
|
||
|
||
func SyncChannelCache(frequency int) {
|
||
for {
|
||
time.Sleep(time.Duration(frequency) * time.Second)
|
||
logger.SysLog("syncing channels from database")
|
||
InitChannelCache()
|
||
}
|
||
}
|
||
|
||
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
|
||
var channel *Channel
|
||
var err error
|
||
selectGroup := group
|
||
if group == "auto" {
|
||
if len(setting.AutoGroups) == 0 {
|
||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||
}
|
||
for _, autoGroup := range setting.AutoGroups {
|
||
if common.DebugEnabled {
|
||
println("autoGroup:", autoGroup)
|
||
}
|
||
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
|
||
if channel == nil {
|
||
continue
|
||
} else {
|
||
c.Set("auto_group", autoGroup)
|
||
selectGroup = autoGroup
|
||
if common.DebugEnabled {
|
||
println("selectGroup:", selectGroup)
|
||
}
|
||
break
|
||
}
|
||
}
|
||
} else {
|
||
channel, err = getRandomSatisfiedChannel(group, model, retry)
|
||
if err != nil {
|
||
return nil, group, err
|
||
}
|
||
}
|
||
return channel, selectGroup, nil
|
||
}
|
||
|
||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||
// if memory cache is disabled, get channel directly from database
|
||
if !common.MemoryCacheEnabled {
|
||
return GetRandomSatisfiedChannel(group, model, retry)
|
||
}
|
||
|
||
channelSyncLock.RLock()
|
||
defer channelSyncLock.RUnlock()
|
||
|
||
// First, try to find channels with the exact model name.
|
||
channels := group2model2channels[group][model]
|
||
|
||
// If no channels found, try to find channels with the normalized model name.
|
||
if len(channels) == 0 {
|
||
normalizedModel := ratio_setting.FormatMatchingModelName(model)
|
||
channels = group2model2channels[group][normalizedModel]
|
||
}
|
||
|
||
if len(channels) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
if len(channels) == 1 {
|
||
if channel, ok := channelsIDM[channels[0]]; ok {
|
||
return channel, nil
|
||
}
|
||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
|
||
}
|
||
|
||
uniquePriorities := make(map[int]bool)
|
||
for _, channelId := range channels {
|
||
if channel, ok := channelsIDM[channelId]; ok {
|
||
uniquePriorities[int(channel.GetPriority())] = true
|
||
} else {
|
||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||
}
|
||
}
|
||
var sortedUniquePriorities []int
|
||
for priority := range uniquePriorities {
|
||
sortedUniquePriorities = append(sortedUniquePriorities, priority)
|
||
}
|
||
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
|
||
|
||
if retry >= len(uniquePriorities) {
|
||
retry = len(uniquePriorities) - 1
|
||
}
|
||
targetPriority := int64(sortedUniquePriorities[retry])
|
||
|
||
// get the priority for the given retry number
|
||
var targetChannels []*Channel
|
||
for _, channelId := range channels {
|
||
if channel, ok := channelsIDM[channelId]; ok {
|
||
if channel.GetPriority() == targetPriority {
|
||
targetChannels = append(targetChannels, channel)
|
||
}
|
||
} else {
|
||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||
}
|
||
}
|
||
|
||
// 平滑系数
|
||
smoothingFactor := 10
|
||
// Calculate the total weight of all channels up to endIdx
|
||
totalWeight := 0
|
||
for _, channel := range targetChannels {
|
||
totalWeight += channel.GetWeight() + smoothingFactor
|
||
}
|
||
// Generate a random value in the range [0, totalWeight)
|
||
randomWeight := rand.Intn(totalWeight)
|
||
|
||
// Find a channel based on its weight
|
||
for _, channel := range targetChannels {
|
||
randomWeight -= channel.GetWeight() + smoothingFactor
|
||
if randomWeight < 0 {
|
||
return channel, nil
|
||
}
|
||
}
|
||
// return null if no channel is not found
|
||
return nil, errors.New("channel not found")
|
||
}
|
||
|
||
func CacheGetChannel(id int) (*Channel, error) {
|
||
if !common.MemoryCacheEnabled {
|
||
return GetChannelById(id, true)
|
||
}
|
||
channelSyncLock.RLock()
|
||
defer channelSyncLock.RUnlock()
|
||
|
||
c, ok := channelsIDM[id]
|
||
if !ok {
|
||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||
}
|
||
return c, nil
|
||
}
|
||
|
||
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
|
||
if !common.MemoryCacheEnabled {
|
||
channel, err := GetChannelById(id, true)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &channel.ChannelInfo, nil
|
||
}
|
||
channelSyncLock.RLock()
|
||
defer channelSyncLock.RUnlock()
|
||
|
||
c, ok := channelsIDM[id]
|
||
if !ok {
|
||
return nil, fmt.Errorf("渠道# %d,已不存在", id)
|
||
}
|
||
return &c.ChannelInfo, nil
|
||
}
|
||
|
||
func CacheUpdateChannelStatus(id int, status int) {
|
||
if !common.MemoryCacheEnabled {
|
||
return
|
||
}
|
||
channelSyncLock.Lock()
|
||
defer channelSyncLock.Unlock()
|
||
if channel, ok := channelsIDM[id]; ok {
|
||
channel.Status = status
|
||
}
|
||
if status != common.ChannelStatusEnabled {
|
||
// delete the channel from group2model2channels
|
||
for group, model2channels := range group2model2channels {
|
||
for model, channels := range model2channels {
|
||
for i, channelId := range channels {
|
||
if channelId == id {
|
||
// remove the channel from the slice
|
||
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func CacheUpdateChannel(channel *Channel) {
|
||
if !common.MemoryCacheEnabled {
|
||
return
|
||
}
|
||
channelSyncLock.Lock()
|
||
defer channelSyncLock.Unlock()
|
||
if channel == nil {
|
||
return
|
||
}
|
||
|
||
println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
|
||
|
||
println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||
channelsIDM[channel.Id] = channel
|
||
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
|
||
}
|