Files
new-api/model/channel_cache.go
CaIon e2037ad756 refactor: Introduce pre-consume quota and unify relay handlers
This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic.

Key changes:
- **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests.

- **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels.

- **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package.

- **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
2025-08-14 20:05:06 +08:00

286 lines
7.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"errors"
"fmt"
"math/rand"
"one-api/common"
"one-api/constant"
"one-api/logger"
"one-api/setting"
"one-api/setting/ratio_setting"
"sort"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// In-memory channel cache state. The maps are replaced wholesale by
// InitChannelCache and are accessed while holding channelSyncLock.
var (
	// group2model2channels maps group name -> model name -> IDs of enabled channels.
	group2model2channels map[string]map[string][]int
	// channelsIDM maps channel ID -> channel for all channels, including disabled ones.
	channelsIDM map[int]*Channel
	// channelSyncLock guards group2model2channels and channelsIDM.
	channelSyncLock sync.RWMutex
)
// InitChannelCache rebuilds the in-memory channel cache from the database.
//
// It does nothing when common.MemoryCacheEnabled is false. Otherwise it loads
// all channels and abilities, builds the group -> model -> channel-ID index
// for enabled channels (each list ordered by descending priority), and swaps
// the new maps in under channelSyncLock.
//
// If the channel query fails, the previous cache is left untouched so that a
// database error does not replace the cache with an empty one.
func InitChannelCache() {
	if !common.MemoryCacheEnabled {
		return
	}

	var channels []*Channel
	if err := DB.Find(&channels).Error; err != nil {
		logger.SysLog("failed to load channels from database, keeping previous channel cache: " + err.Error())
		return
	}
	newChannelId2channel := make(map[int]*Channel, len(channels))
	for _, channel := range channels {
		newChannelId2channel[channel.Id] = channel
	}

	// Abilities are only used to pre-create an (possibly empty) entry for each
	// known group, so a failure here is logged but does not abort the rebuild.
	var abilities []*Ability
	if err := DB.Find(&abilities).Error; err != nil {
		logger.SysLog("failed to load abilities from database: " + err.Error())
	}
	newGroup2model2channels := make(map[string]map[string][]int)
	for _, ability := range abilities {
		if _, ok := newGroup2model2channels[ability.Group]; !ok {
			newGroup2model2channels[ability.Group] = make(map[string][]int)
		}
	}

	for _, channel := range channels {
		if channel.Status != common.ChannelStatusEnabled {
			continue // skip disabled channels
		}
		models := strings.Split(channel.Models, ",")
		for _, group := range strings.Split(channel.Group, ",") {
			// A channel may list a group that has no ability row; create the
			// inner map on demand instead of writing into a nil map.
			model2channels, ok := newGroup2model2channels[group]
			if !ok {
				model2channels = make(map[string][]int)
				newGroup2model2channels[group] = model2channels
			}
			for _, model := range models {
				model2channels[model] = append(model2channels[model], channel.Id)
			}
		}
	}

	// Order every channel list by descending priority. sort.Slice sorts in
	// place, so the map entries already reference the sorted slices.
	for _, model2channels := range newGroup2model2channels {
		for _, ids := range model2channels {
			sort.Slice(ids, func(i, j int) bool {
				return newChannelId2channel[ids[i]].GetPriority() > newChannelId2channel[ids[j]].GetPriority()
			})
		}
	}

	channelSyncLock.Lock()
	group2model2channels = newGroup2model2channels
	for id, channel := range newChannelId2channel {
		if !channel.ChannelInfo.IsMultiKey {
			continue
		}
		channel.Keys = channel.GetKeys()
		if channel.ChannelInfo.MultiKeyMode != constant.MultiKeyModePolling {
			continue
		}
		// If the previously cached channel was also multi-key in polling mode,
		// carry its polling index over to the freshly loaded channel.
		if oldChannel, ok := channelsIDM[id]; ok {
			if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
				channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
			}
		}
	}
	channelsIDM = newChannelId2channel
	channelSyncLock.Unlock()
	logger.SysLog("channels synced from database")
}
// SyncChannelCache refreshes the channel cache forever: it sleeps for
// frequency seconds, logs, calls InitChannelCache, and repeats. It never
// returns.
func SyncChannelCache(frequency int) {
	interval := time.Duration(frequency) * time.Second
	for {
		time.Sleep(interval)
		logger.SysLog("syncing channels from database")
		InitChannelCache()
	}
}
// CacheGetRandomSatisfiedChannel picks a channel for the given group and model.
//
// For an ordinary group it delegates to getRandomSatisfiedChannel and returns
// its result together with the group name. For the special group "auto" it
// walks setting.AutoGroups in order and returns the first group that yields a
// channel, recording that group in the gin context under "auto_group".
//
// The returned string is the group that was actually selected; it stays equal
// to the requested group when no auto group matched. The channel may be nil
// with a nil error when no channel was found.
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
	if group != "auto" {
		picked, err := getRandomSatisfiedChannel(group, model, retry)
		if err != nil {
			return nil, group, err
		}
		return picked, group, nil
	}

	if len(setting.AutoGroups) == 0 {
		return nil, group, errors.New("auto groups is not enabled")
	}
	for _, candidate := range setting.AutoGroups {
		if common.DebugEnabled {
			println("autoGroup:", candidate)
		}
		// Errors for a single auto group are ignored; the next group is tried.
		picked, _ := getRandomSatisfiedChannel(candidate, model, retry)
		if picked == nil {
			continue
		}
		c.Set("auto_group", candidate)
		if common.DebugEnabled {
			println("selectGroup:", candidate)
		}
		return picked, candidate, nil
	}
	// No auto group produced a channel.
	return nil, group, nil
}
// getRandomSatisfiedChannel selects a channel for group/model from the
// in-memory cache, or from the database when the memory cache is disabled.
//
// Lookup uses the exact model name first and falls back to the name returned
// by ratio_setting.FormatMatchingModelName. Among the candidates, the distinct
// priorities are ordered from highest to lowest and retry selects the tier:
// 0 is the highest priority, and values outside the valid range are clamped.
// Within the chosen tier a channel is drawn at random, weighted by
// GetWeight() plus a smoothing factor.
//
// It returns (nil, nil) when no channel serves the group/model, and an error
// when the index references a channel ID that is missing from channelsIDM.
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
	// if memory cache is disabled, get channel directly from database
	if !common.MemoryCacheEnabled {
		return GetRandomSatisfiedChannel(group, model, retry)
	}

	channelSyncLock.RLock()
	defer channelSyncLock.RUnlock()

	// First, try to find channels with the exact model name.
	channels := group2model2channels[group][model]
	// If no channels found, try to find channels with the normalized model name.
	if len(channels) == 0 {
		normalizedModel := ratio_setting.FormatMatchingModelName(model)
		channels = group2model2channels[group][normalizedModel]
	}
	if len(channels) == 0 {
		return nil, nil
	}

	// Resolve every candidate ID once; a missing ID means the index and the
	// channel map disagree.
	candidates := make([]*Channel, 0, len(channels))
	for _, channelId := range channels {
		channel, ok := channelsIDM[channelId]
		if !ok {
			return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
		}
		candidates = append(candidates, channel)
	}
	if len(candidates) == 1 {
		return candidates[0], nil
	}

	// Collect the distinct priorities, highest first. Priorities are kept as
	// int64 so they are not narrowed before comparison.
	uniquePriorities := make(map[int64]struct{})
	for _, channel := range candidates {
		uniquePriorities[channel.GetPriority()] = struct{}{}
	}
	sortedUniquePriorities := make([]int64, 0, len(uniquePriorities))
	for priority := range uniquePriorities {
		sortedUniquePriorities = append(sortedUniquePriorities, priority)
	}
	sort.Slice(sortedUniquePriorities, func(i, j int) bool {
		return sortedUniquePriorities[i] > sortedUniquePriorities[j]
	})

	// Clamp retry into the valid tier range so that neither a negative value
	// nor a value past the last tier indexes out of bounds.
	if retry < 0 {
		retry = 0
	}
	if retry >= len(sortedUniquePriorities) {
		retry = len(sortedUniquePriorities) - 1
	}
	targetPriority := sortedUniquePriorities[retry]

	// Keep only the channels in the selected priority tier.
	var targetChannels []*Channel
	for _, channel := range candidates {
		if channel.GetPriority() == targetPriority {
			targetChannels = append(targetChannels, channel)
		}
	}

	// Smoothing factor added to every channel's weight, so that channels with
	// weight 0 can still be selected.
	const smoothingFactor = 10
	// Total (smoothed) weight of all channels in the selected tier.
	totalWeight := 0
	for _, channel := range targetChannels {
		totalWeight += channel.GetWeight() + smoothingFactor
	}
	// rand.Intn panics for a non-positive argument; treat that as "not found".
	if totalWeight <= 0 {
		return nil, errors.New("channel not found")
	}
	// Generate a random value in the range [0, totalWeight)
	randomWeight := rand.Intn(totalWeight)
	// Walk the tier, subtracting each weight until the random value is used up.
	for _, channel := range targetChannels {
		randomWeight -= channel.GetWeight() + smoothingFactor
		if randomWeight < 0 {
			return channel, nil
		}
	}
	// No channel was selected.
	return nil, errors.New("channel not found")
}
// CacheGetChannel returns the channel with the given ID.
//
// With the memory cache enabled the channel is read from channelsIDM under a
// read lock, and an error is returned if the ID is not cached. Otherwise the
// lookup is delegated to GetChannelById(id, true).
func CacheGetChannel(id int) (*Channel, error) {
	if common.MemoryCacheEnabled {
		channelSyncLock.RLock()
		defer channelSyncLock.RUnlock()
		if cached, found := channelsIDM[id]; found {
			return cached, nil
		}
		return nil, fmt.Errorf("渠道# %d已不存在", id)
	}
	return GetChannelById(id, true)
}
// CacheGetChannelInfo returns a pointer to the ChannelInfo of the channel with
// the given ID.
//
// With the memory cache enabled the channel is read from channelsIDM under a
// read lock, and an error is returned if the ID is not cached. Otherwise the
// channel is fetched with GetChannelById(id, true) and any error from that
// call is returned unchanged.
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
	if common.MemoryCacheEnabled {
		channelSyncLock.RLock()
		defer channelSyncLock.RUnlock()
		cached, found := channelsIDM[id]
		if !found {
			return nil, fmt.Errorf("渠道# %d已不存在", id)
		}
		return &cached.ChannelInfo, nil
	}

	fetched, err := GetChannelById(id, true)
	if err != nil {
		return nil, err
	}
	return &fetched.ChannelInfo, nil
}
// CacheUpdateChannelStatus records a new status for the cached channel with
// the given ID. It does nothing when the memory cache is disabled.
//
// When the new status is anything other than common.ChannelStatusEnabled, the
// channel ID is also removed from every group/model list in
// group2model2channels. Setting the status to enabled does not add the ID back
// to those lists.
func CacheUpdateChannelStatus(id int, status int) {
	if !common.MemoryCacheEnabled {
		return
	}
	channelSyncLock.Lock()
	defer channelSyncLock.Unlock()

	if cached, found := channelsIDM[id]; found {
		cached.Status = status
	}
	if status == common.ChannelStatusEnabled {
		return
	}

	// Drop the channel from every list of enabled channels.
	for _, byModel := range group2model2channels {
		for modelName, ids := range byModel {
			for idx := range ids {
				if ids[idx] != id {
					continue
				}
				// Remove the first occurrence and move on to the next model.
				byModel[modelName] = append(ids[:idx], ids[idx+1:]...)
				break
			}
		}
	}
}
// CacheUpdateChannel stores channel in channelsIDM under its ID, replacing any
// previously cached entry. It does nothing when the memory cache is disabled
// or channel is nil.
//
// When common.DebugEnabled is set, the multi-key polling index is traced
// before and after the replacement.
func CacheUpdateChannel(channel *Channel) {
	if !common.MemoryCacheEnabled || channel == nil {
		return
	}
	channelSyncLock.Lock()
	defer channelSyncLock.Unlock()

	// The map is nil until the first InitChannelCache run; writing to a nil
	// map would panic.
	if channelsIDM == nil {
		channelsIDM = make(map[int]*Channel)
	}

	if common.DebugEnabled {
		println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
		// The channel may not be cached yet, so the previous entry can be absent.
		if previous, ok := channelsIDM[channel.Id]; ok && previous != nil {
			println("before:", previous.ChannelInfo.MultiKeyPollingIndex)
		}
	}
	channelsIDM[channel.Id] = channel
	if common.DebugEnabled {
		println("after :", channel.ChannelInfo.MultiKeyPollingIndex)
	}
}