This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability. - Minor adjustments to improve code clarity and organization throughout various modules. This change aims to streamline logging and improve the overall architecture of the codebase.
962 lines
27 KiB
Go
962 lines
27 KiB
Go
package model
|
||
|
||
import (
|
||
"database/sql/driver"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"math/rand"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/dto"
|
||
"one-api/types"
|
||
"strings"
|
||
"sync"
|
||
|
||
"github.com/samber/lo"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type Channel struct {
|
||
Id int `json:"id"`
|
||
Type int `json:"type" gorm:"default:0"`
|
||
Key string `json:"key" gorm:"not null"`
|
||
OpenAIOrganization *string `json:"openai_organization"`
|
||
TestModel *string `json:"test_model"`
|
||
Status int `json:"status" gorm:"default:1"`
|
||
Name string `json:"name" gorm:"index"`
|
||
Weight *uint `json:"weight" gorm:"default:0"`
|
||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||
TestTime int64 `json:"test_time" gorm:"bigint"`
|
||
ResponseTime int `json:"response_time"` // in milliseconds
|
||
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
|
||
Other string `json:"other"`
|
||
Balance float64 `json:"balance"` // in USD
|
||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||
Models string `json:"models"`
|
||
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
|
||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||
ModelMapping *string `json:"model_mapping" gorm:"type:text"`
|
||
//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
|
||
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
|
||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||
OtherInfo string `json:"other_info"`
|
||
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置
|
||
Tag *string `json:"tag" gorm:"index"`
|
||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||
// add after v0.8.5
|
||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||
|
||
// cache info
|
||
Keys []string `json:"-" gorm:"-"`
|
||
}
|
||
|
||
type ChannelInfo struct {
|
||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
|
||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||
MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason
|
||
MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time
|
||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||
}
|
||
|
||
// Value implements driver.Valuer interface
|
||
func (c ChannelInfo) Value() (driver.Value, error) {
|
||
return common.Marshal(&c)
|
||
}
|
||
|
||
// Scan implements sql.Scanner interface
|
||
func (c *ChannelInfo) Scan(value interface{}) error {
|
||
bytesValue, _ := value.([]byte)
|
||
return common.Unmarshal(bytesValue, c)
|
||
}
|
||
|
||
func (channel *Channel) GetKeys() []string {
|
||
if channel.Key == "" {
|
||
return []string{}
|
||
}
|
||
if len(channel.Keys) > 0 {
|
||
return channel.Keys
|
||
}
|
||
trimmed := strings.TrimSpace(channel.Key)
|
||
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
|
||
if strings.HasPrefix(trimmed, "[") {
|
||
var arr []json.RawMessage
|
||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||
res := make([]string, len(arr))
|
||
for i, v := range arr {
|
||
res[i] = string(v)
|
||
}
|
||
return res
|
||
}
|
||
}
|
||
// Otherwise, fall back to splitting by newline
|
||
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
|
||
return keys
|
||
}
|
||
|
||
func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
|
||
// If not in multi-key mode, return the original key string directly.
|
||
if !channel.ChannelInfo.IsMultiKey {
|
||
return channel.Key, 0, nil
|
||
}
|
||
|
||
// Obtain all keys (split by \n)
|
||
keys := channel.GetKeys()
|
||
if len(keys) == 0 {
|
||
// No keys available, return error, should disable the channel
|
||
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
|
||
}
|
||
|
||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||
// helper to get key status, default to enabled when missing
|
||
getStatus := func(idx int) int {
|
||
if statusList == nil {
|
||
return common.ChannelStatusEnabled
|
||
}
|
||
if status, ok := statusList[idx]; ok {
|
||
return status
|
||
}
|
||
return common.ChannelStatusEnabled
|
||
}
|
||
|
||
// Collect indexes of enabled keys
|
||
enabledIdx := make([]int, 0, len(keys))
|
||
for i := range keys {
|
||
if getStatus(i) == common.ChannelStatusEnabled {
|
||
enabledIdx = append(enabledIdx, i)
|
||
}
|
||
}
|
||
// If no specific status list or none enabled, fall back to first key
|
||
if len(enabledIdx) == 0 {
|
||
return keys[0], 0, nil
|
||
}
|
||
|
||
switch channel.ChannelInfo.MultiKeyMode {
|
||
case constant.MultiKeyModeRandom:
|
||
// Randomly pick one enabled key
|
||
selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
|
||
return keys[selectedIdx], selectedIdx, nil
|
||
case constant.MultiKeyModePolling:
|
||
// Use channel-specific lock to ensure thread-safe polling
|
||
lock := GetChannelPollingLock(channel.Id)
|
||
lock.Lock()
|
||
defer lock.Unlock()
|
||
|
||
channelInfo, err := CacheGetChannelInfo(channel.Id)
|
||
if err != nil {
|
||
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||
}
|
||
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
|
||
defer func() {
|
||
if common.DebugEnabled {
|
||
println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
|
||
}
|
||
if !common.MemoryCacheEnabled {
|
||
_ = channel.SaveChannelInfo()
|
||
} else {
|
||
// CacheUpdateChannel(channel)
|
||
}
|
||
}()
|
||
// Start from the saved polling index and look for the next enabled key
|
||
start := channelInfo.MultiKeyPollingIndex
|
||
if start < 0 || start >= len(keys) {
|
||
start = 0
|
||
}
|
||
for i := 0; i < len(keys); i++ {
|
||
idx := (start + i) % len(keys)
|
||
if getStatus(idx) == common.ChannelStatusEnabled {
|
||
// update polling index for next call (point to the next position)
|
||
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
|
||
return keys[idx], idx, nil
|
||
}
|
||
}
|
||
// Fallback – should not happen, but return first enabled key
|
||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||
default:
|
||
// Unknown mode, default to first enabled key (or original key string)
|
||
return keys[enabledIdx[0]], enabledIdx[0], nil
|
||
}
|
||
}
|
||
|
||
func (channel *Channel) SaveChannelInfo() error {
|
||
return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
|
||
}
|
||
|
||
func (channel *Channel) GetModels() []string {
|
||
if channel.Models == "" {
|
||
return []string{}
|
||
}
|
||
return strings.Split(strings.Trim(channel.Models, ","), ",")
|
||
}
|
||
|
||
func (channel *Channel) GetGroups() []string {
|
||
if channel.Group == "" {
|
||
return []string{}
|
||
}
|
||
groups := strings.Split(strings.Trim(channel.Group, ","), ",")
|
||
for i, group := range groups {
|
||
groups[i] = strings.TrimSpace(group)
|
||
}
|
||
return groups
|
||
}
|
||
|
||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
||
otherInfo := make(map[string]interface{})
|
||
if channel.OtherInfo != "" {
|
||
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
||
if err != nil {
|
||
common.SysLog("failed to unmarshal other info: " + err.Error())
|
||
}
|
||
}
|
||
return otherInfo
|
||
}
|
||
|
||
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
||
otherInfoBytes, err := json.Marshal(otherInfo)
|
||
if err != nil {
|
||
common.SysLog("failed to marshal other info: " + err.Error())
|
||
return
|
||
}
|
||
channel.OtherInfo = string(otherInfoBytes)
|
||
}
|
||
|
||
func (channel *Channel) GetTag() string {
|
||
if channel.Tag == nil {
|
||
return ""
|
||
}
|
||
return *channel.Tag
|
||
}
|
||
|
||
func (channel *Channel) SetTag(tag string) {
|
||
channel.Tag = &tag
|
||
}
|
||
|
||
func (channel *Channel) GetAutoBan() bool {
|
||
if channel.AutoBan == nil {
|
||
return false
|
||
}
|
||
return *channel.AutoBan == 1
|
||
}
|
||
|
||
func (channel *Channel) Save() error {
|
||
return DB.Save(channel).Error
|
||
}
|
||
|
||
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
var err error
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
if selectAll {
|
||
err = DB.Order(order).Find(&channels).Error
|
||
} else {
|
||
err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||
}
|
||
return channels, err
|
||
}
|
||
|
||
func GetChannelsByTag(tag string, idSort bool) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
err := DB.Where("tag = ?", tag).Order(order).Find(&channels).Error
|
||
return channels, err
|
||
}
|
||
|
||
func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
modelsCol := "`models`"
|
||
|
||
// 如果是 PostgreSQL,使用双引号
|
||
if common.UsingPostgreSQL {
|
||
modelsCol = `"models"`
|
||
}
|
||
|
||
baseURLCol := "`base_url`"
|
||
// 如果是 PostgreSQL,使用双引号
|
||
if common.UsingPostgreSQL {
|
||
baseURLCol = `"base_url"`
|
||
}
|
||
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
|
||
// 构造基础查询
|
||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||
|
||
// 构造WHERE子句
|
||
var whereClause string
|
||
var args []interface{}
|
||
if group != "" && group != "null" {
|
||
var groupCondition string
|
||
if common.UsingMySQL {
|
||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||
} else {
|
||
// sqlite, PostgreSQL
|
||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||
}
|
||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||
} else {
|
||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||
}
|
||
|
||
// 执行查询
|
||
err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return channels, nil
|
||
}
|
||
|
||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||
channel := &Channel{Id: id}
|
||
var err error = nil
|
||
if selectAll {
|
||
err = DB.First(channel, "id = ?", id).Error
|
||
} else {
|
||
err = DB.Omit("key").First(channel, "id = ?", id).Error
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if channel == nil {
|
||
return nil, errors.New("channel not found")
|
||
}
|
||
return channel, nil
|
||
}
|
||
|
||
func BatchInsertChannels(channels []Channel) error {
|
||
if len(channels) == 0 {
|
||
return nil
|
||
}
|
||
tx := DB.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
tx.Rollback()
|
||
}
|
||
}()
|
||
|
||
for _, chunk := range lo.Chunk(channels, 50) {
|
||
if err := tx.Create(&chunk).Error; err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
for _, channel_ := range chunk {
|
||
if err := channel_.AddAbilities(tx); err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
return tx.Commit().Error
|
||
}
|
||
|
||
func BatchDeleteChannels(ids []int) error {
|
||
if len(ids) == 0 {
|
||
return nil
|
||
}
|
||
// 使用事务 分批删除channel表和abilities表
|
||
tx := DB.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
for _, chunk := range lo.Chunk(ids, 200) {
|
||
if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
}
|
||
return tx.Commit().Error
|
||
}
|
||
|
||
func (channel *Channel) GetPriority() int64 {
|
||
if channel.Priority == nil {
|
||
return 0
|
||
}
|
||
return *channel.Priority
|
||
}
|
||
|
||
func (channel *Channel) GetWeight() int {
|
||
if channel.Weight == nil {
|
||
return 0
|
||
}
|
||
return int(*channel.Weight)
|
||
}
|
||
|
||
func (channel *Channel) GetBaseURL() string {
|
||
if channel.BaseURL == nil {
|
||
return ""
|
||
}
|
||
return *channel.BaseURL
|
||
}
|
||
|
||
func (channel *Channel) GetModelMapping() string {
|
||
if channel.ModelMapping == nil {
|
||
return ""
|
||
}
|
||
return *channel.ModelMapping
|
||
}
|
||
|
||
func (channel *Channel) GetStatusCodeMapping() string {
|
||
if channel.StatusCodeMapping == nil {
|
||
return ""
|
||
}
|
||
return *channel.StatusCodeMapping
|
||
}
|
||
|
||
func (channel *Channel) Insert() error {
|
||
var err error
|
||
err = DB.Create(channel).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
err = channel.AddAbilities(nil)
|
||
return err
|
||
}
|
||
|
||
func (channel *Channel) Update() error {
|
||
// If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys
|
||
if channel.ChannelInfo.IsMultiKey {
|
||
var keyStr string
|
||
if channel.Key != "" {
|
||
keyStr = channel.Key
|
||
} else {
|
||
// If key is not provided, read the existing key from the database
|
||
if existing, err := GetChannelById(channel.Id, true); err == nil {
|
||
keyStr = existing.Key
|
||
}
|
||
}
|
||
// Parse the key list (supports newline separation or JSON array)
|
||
keys := []string{}
|
||
if keyStr != "" {
|
||
trimmed := strings.TrimSpace(keyStr)
|
||
if strings.HasPrefix(trimmed, "[") {
|
||
var arr []json.RawMessage
|
||
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||
keys = make([]string, len(arr))
|
||
for i, v := range arr {
|
||
keys[i] = string(v)
|
||
}
|
||
}
|
||
}
|
||
if len(keys) == 0 { // fallback to newline split
|
||
keys = strings.Split(strings.Trim(keyStr, "\n"), "\n")
|
||
}
|
||
}
|
||
channel.ChannelInfo.MultiKeySize = len(keys)
|
||
// Clean up status data that exceeds the new key count to prevent index out of range
|
||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||
for idx := range channel.ChannelInfo.MultiKeyStatusList {
|
||
if idx >= channel.ChannelInfo.MultiKeySize {
|
||
delete(channel.ChannelInfo.MultiKeyStatusList, idx)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
var err error
|
||
err = DB.Model(channel).Updates(channel).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
DB.Model(channel).First(channel, "id = ?", channel.Id)
|
||
err = channel.UpdateAbilities(nil)
|
||
return err
|
||
}
|
||
|
||
func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
||
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
|
||
TestTime: common.GetTimestamp(),
|
||
ResponseTime: int(responseTime),
|
||
}).Error
|
||
if err != nil {
|
||
common.SysLog("failed to update response time: " + err.Error())
|
||
}
|
||
}
|
||
|
||
func (channel *Channel) UpdateBalance(balance float64) {
|
||
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
|
||
BalanceUpdatedTime: common.GetTimestamp(),
|
||
Balance: balance,
|
||
}).Error
|
||
if err != nil {
|
||
common.SysLog("failed to update balance: " + err.Error())
|
||
}
|
||
}
|
||
|
||
func (channel *Channel) Delete() error {
|
||
var err error
|
||
err = DB.Delete(channel).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
err = channel.DeleteAbilities()
|
||
return err
|
||
}
|
||
|
||
var channelStatusLock sync.Mutex
|
||
|
||
// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
|
||
var channelPollingLocks sync.Map
|
||
|
||
// GetChannelPollingLock returns or creates a mutex for the given channel ID
|
||
func GetChannelPollingLock(channelId int) *sync.Mutex {
|
||
if lock, exists := channelPollingLocks.Load(channelId); exists {
|
||
return lock.(*sync.Mutex)
|
||
}
|
||
// Create new lock for this channel
|
||
newLock := &sync.Mutex{}
|
||
actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
|
||
return actual.(*sync.Mutex)
|
||
}
|
||
|
||
// CleanupChannelPollingLocks removes locks for channels that no longer exist
|
||
// This is optional and can be called periodically to prevent memory leaks
|
||
func CleanupChannelPollingLocks() {
|
||
var activeChannelIds []int
|
||
DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
|
||
|
||
activeChannelSet := make(map[int]bool)
|
||
for _, id := range activeChannelIds {
|
||
activeChannelSet[id] = true
|
||
}
|
||
|
||
channelPollingLocks.Range(func(key, value interface{}) bool {
|
||
channelId := key.(int)
|
||
if !activeChannelSet[channelId] {
|
||
channelPollingLocks.Delete(channelId)
|
||
}
|
||
return true
|
||
})
|
||
}
|
||
|
||
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
|
||
keys := channel.GetKeys()
|
||
if len(keys) == 0 {
|
||
channel.Status = status
|
||
} else {
|
||
var keyIndex int
|
||
for i, key := range keys {
|
||
if key == usingKey {
|
||
keyIndex = i
|
||
break
|
||
}
|
||
}
|
||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||
}
|
||
if status == common.ChannelStatusEnabled {
|
||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||
} else {
|
||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
|
||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||
}
|
||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||
}
|
||
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
|
||
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
|
||
}
|
||
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
|
||
channel.Status = common.ChannelStatusAutoDisabled
|
||
info := channel.GetOtherInfo()
|
||
info["status_reason"] = "All keys are disabled"
|
||
info["status_time"] = common.GetTimestamp()
|
||
channel.SetOtherInfo(info)
|
||
}
|
||
}
|
||
}
|
||
|
||
func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
|
||
if common.MemoryCacheEnabled {
|
||
channelStatusLock.Lock()
|
||
defer channelStatusLock.Unlock()
|
||
|
||
channelCache, _ := CacheGetChannel(channelId)
|
||
if channelCache == nil {
|
||
return false
|
||
}
|
||
if channelCache.ChannelInfo.IsMultiKey {
|
||
// 如果是多Key模式,更新缓存中的状态
|
||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||
//CacheUpdateChannel(channelCache)
|
||
//return true
|
||
} else {
|
||
// 如果缓存渠道存在,且状态已是目标状态,直接返回
|
||
if channelCache.Status == status {
|
||
return false
|
||
}
|
||
CacheUpdateChannelStatus(channelId, status)
|
||
}
|
||
}
|
||
|
||
shouldUpdateAbilities := false
|
||
defer func() {
|
||
if shouldUpdateAbilities {
|
||
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
|
||
if err != nil {
|
||
common.SysLog("failed to update ability status: " + err.Error())
|
||
}
|
||
}
|
||
}()
|
||
channel, err := GetChannelById(channelId, true)
|
||
if err != nil {
|
||
return false
|
||
} else {
|
||
if channel.Status == status {
|
||
return false
|
||
}
|
||
|
||
if channel.ChannelInfo.IsMultiKey {
|
||
beforeStatus := channel.Status
|
||
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
||
if beforeStatus != channel.Status {
|
||
shouldUpdateAbilities = true
|
||
}
|
||
} else {
|
||
info := channel.GetOtherInfo()
|
||
info["status_reason"] = reason
|
||
info["status_time"] = common.GetTimestamp()
|
||
channel.SetOtherInfo(info)
|
||
channel.Status = status
|
||
shouldUpdateAbilities = true
|
||
}
|
||
err = channel.Save()
|
||
if err != nil {
|
||
common.SysLog("failed to update channel status: " + err.Error())
|
||
return false
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
|
||
func EnableChannelByTag(tag string) error {
|
||
err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
err = UpdateAbilityStatusByTag(tag, true)
|
||
return err
|
||
}
|
||
|
||
func DisableChannelByTag(tag string) error {
|
||
err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
err = UpdateAbilityStatusByTag(tag, false)
|
||
return err
|
||
}
|
||
|
||
func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error {
|
||
updateData := Channel{}
|
||
shouldReCreateAbilities := false
|
||
updatedTag := tag
|
||
// 如果 newTag 不为空且不等于 tag,则更新 tag
|
||
if newTag != nil && *newTag != tag {
|
||
updateData.Tag = newTag
|
||
updatedTag = *newTag
|
||
}
|
||
if modelMapping != nil && *modelMapping != "" {
|
||
updateData.ModelMapping = modelMapping
|
||
}
|
||
if models != nil && *models != "" {
|
||
shouldReCreateAbilities = true
|
||
updateData.Models = *models
|
||
}
|
||
if group != nil && *group != "" {
|
||
shouldReCreateAbilities = true
|
||
updateData.Group = *group
|
||
}
|
||
if priority != nil {
|
||
updateData.Priority = priority
|
||
}
|
||
if weight != nil {
|
||
updateData.Weight = weight
|
||
}
|
||
|
||
err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if shouldReCreateAbilities {
|
||
channels, err := GetChannelsByTag(updatedTag, false)
|
||
if err == nil {
|
||
for _, channel := range channels {
|
||
err = channel.UpdateAbilities(nil)
|
||
if err != nil {
|
||
common.SysLog("failed to update abilities: " + err.Error())
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
err := UpdateAbilityByTag(tag, newTag, priority, weight)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func UpdateChannelUsedQuota(id int, quota int) {
|
||
if common.BatchUpdateEnabled {
|
||
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
|
||
return
|
||
}
|
||
updateChannelUsedQuota(id, quota)
|
||
}
|
||
|
||
func updateChannelUsedQuota(id int, quota int) {
|
||
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
|
||
if err != nil {
|
||
common.SysLog("failed to update channel used quota: " + err.Error())
|
||
}
|
||
}
|
||
|
||
func DeleteChannelByStatus(status int64) (int64, error) {
|
||
result := DB.Where("status = ?", status).Delete(&Channel{})
|
||
return result.RowsAffected, result.Error
|
||
}
|
||
|
||
func DeleteDisabledChannel() (int64, error) {
|
||
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
|
||
return result.RowsAffected, result.Error
|
||
}
|
||
|
||
func GetPaginatedTags(offset int, limit int) ([]*string, error) {
|
||
var tags []*string
|
||
err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
|
||
return tags, err
|
||
}
|
||
|
||
func SearchTags(keyword string, group string, model string, idSort bool) ([]*string, error) {
|
||
var tags []*string
|
||
modelsCol := "`models`"
|
||
|
||
// 如果是 PostgreSQL,使用双引号
|
||
if common.UsingPostgreSQL {
|
||
modelsCol = `"models"`
|
||
}
|
||
|
||
baseURLCol := "`base_url`"
|
||
// 如果是 PostgreSQL,使用双引号
|
||
if common.UsingPostgreSQL {
|
||
baseURLCol = `"base_url"`
|
||
}
|
||
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
|
||
// 构造基础查询
|
||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||
|
||
// 构造WHERE子句
|
||
var whereClause string
|
||
var args []interface{}
|
||
if group != "" && group != "null" {
|
||
var groupCondition string
|
||
if common.UsingMySQL {
|
||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||
} else {
|
||
// sqlite, PostgreSQL
|
||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||
}
|
||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||
} else {
|
||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||
}
|
||
|
||
subQuery := baseQuery.Where(whereClause, args...).
|
||
Select("tag").
|
||
Where("tag != ''").
|
||
Order(order)
|
||
|
||
err := DB.Table("(?) as sub", subQuery).
|
||
Select("DISTINCT tag").
|
||
Find(&tags).Error
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return tags, nil
|
||
}
|
||
|
||
func (channel *Channel) ValidateSettings() error {
|
||
channelParams := &dto.ChannelSettings{}
|
||
if channel.Setting != nil && *channel.Setting != "" {
|
||
err := common.Unmarshal([]byte(*channel.Setting), channelParams)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||
setting := dto.ChannelSettings{}
|
||
if channel.Setting != nil && *channel.Setting != "" {
|
||
err := common.Unmarshal([]byte(*channel.Setting), &setting)
|
||
if err != nil {
|
||
common.SysLog("failed to unmarshal setting: " + err.Error())
|
||
channel.Setting = nil // 清空设置以避免后续错误
|
||
_ = channel.Save() // 保存修改
|
||
}
|
||
}
|
||
return setting
|
||
}
|
||
|
||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||
settingBytes, err := common.Marshal(setting)
|
||
if err != nil {
|
||
common.SysLog("failed to marshal setting: " + err.Error())
|
||
return
|
||
}
|
||
channel.Setting = common.GetPointer[string](string(settingBytes))
|
||
}
|
||
|
||
func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
|
||
setting := dto.ChannelOtherSettings{}
|
||
if channel.OtherSettings != "" {
|
||
err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
|
||
if err != nil {
|
||
common.SysLog("failed to unmarshal setting: " + err.Error())
|
||
channel.OtherSettings = "{}" // 清空设置以避免后续错误
|
||
_ = channel.Save() // 保存修改
|
||
}
|
||
}
|
||
return setting
|
||
}
|
||
|
||
func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
|
||
settingBytes, err := common.Marshal(setting)
|
||
if err != nil {
|
||
common.SysLog("failed to marshal setting: " + err.Error())
|
||
return
|
||
}
|
||
channel.OtherSettings = string(settingBytes)
|
||
}
|
||
|
||
func (channel *Channel) GetParamOverride() map[string]interface{} {
|
||
paramOverride := make(map[string]interface{})
|
||
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
|
||
err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
|
||
if err != nil {
|
||
common.SysLog("failed to unmarshal param override: " + err.Error())
|
||
}
|
||
}
|
||
return paramOverride
|
||
}
|
||
|
||
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
err := DB.Where("id in (?)", ids).Find(&channels).Error
|
||
return channels, err
|
||
}
|
||
|
||
func BatchSetChannelTag(ids []int, tag *string) error {
|
||
// 开启事务
|
||
tx := DB.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
|
||
// 更新标签
|
||
err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
// update ability status
|
||
channels, err := GetChannelsByIds(ids)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
for _, channel := range channels {
|
||
err = channel.UpdateAbilities(tx)
|
||
if err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 提交事务
|
||
return tx.Commit().Error
|
||
}
|
||
|
||
// CountAllChannels returns total channels in DB
|
||
func CountAllChannels() (int64, error) {
|
||
var total int64
|
||
err := DB.Model(&Channel{}).Count(&total).Error
|
||
return total, err
|
||
}
|
||
|
||
// CountAllTags returns number of non-empty distinct tags
|
||
func CountAllTags() (int64, error) {
|
||
var total int64
|
||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||
return total, err
|
||
}
|
||
|
||
// Get channels of specified type with pagination
|
||
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
|
||
var channels []*Channel
|
||
order := "priority desc"
|
||
if idSort {
|
||
order = "id desc"
|
||
}
|
||
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||
return channels, err
|
||
}
|
||
|
||
// Count channels of specific type
|
||
func CountChannelsByType(channelType int) (int64, error) {
|
||
var count int64
|
||
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
|
||
return count, err
|
||
}
|
||
|
||
// Return map[type]count for all channels
|
||
func CountChannelsGroupByType() (map[int64]int64, error) {
|
||
type result struct {
|
||
Type int64 `gorm:"column:type"`
|
||
Count int64 `gorm:"column:count"`
|
||
}
|
||
var results []result
|
||
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
counts := make(map[int64]int64)
|
||
for _, r := range results {
|
||
counts[r.Type] = r.Count
|
||
}
|
||
return counts, nil
|
||
}
|