Merge branch 'alpha' into mutil_key_channel

# Conflicts:
#	controller/channel.go
#	docker-compose.yml
#	web/src/components/table/ChannelsTable.js
#	web/src/pages/Channel/EditChannel.js
This commit is contained in:
CaIon
2025-07-06 10:33:48 +08:00
231 changed files with 12575 additions and 6119 deletions

View File

@@ -21,7 +21,22 @@ type Ability struct {
Tag *string `json:"tag" gorm:"index"`
}
func GetGroupModels(group string) []string {
type AbilityWithChannel struct {
Ability
ChannelType int `json:"channel_type"`
}
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
var abilities []AbilityWithChannel
err := DB.Table("abilities").
Select("abilities.*, channels.type as channel_type").
Joins("left join channels on abilities.channel_id = channels.id").
Where("abilities.enabled = ?", true).
Scan(&abilities).Error
return abilities, err
}
func GetGroupEnabledModels(group string) []string {
var models []string
// Find distinct models
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
@@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) {
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
@@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) {
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
}
}

View File

@@ -5,10 +5,13 @@ import (
"fmt"
"math/rand"
"one-api/common"
"one-api/setting"
"sort"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]*Channel
@@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) {
}
}
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
var channel *Channel
var err error
selectGroup := group
if group == "auto" {
if len(setting.AutoGroups) == 0 {
return nil, selectGroup, errors.New("auto groups is not enabled")
}
for _, autoGroup := range setting.AutoGroups {
if common.DebugEnabled {
println("autoGroup:", autoGroup)
}
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
if channel == nil {
continue
} else {
c.Set("auto_group", autoGroup)
selectGroup = autoGroup
if common.DebugEnabled {
println("selectGroup:", selectGroup)
}
break
}
}
} else {
channel, err = getRandomSatisfiedChannel(group, model, retry)
if err != nil {
return nil, group, err
}
}
if channel == nil {
return nil, group, errors.New("channel not found")
}
return channel, selectGroup, nil
}
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}

View File

@@ -617,3 +617,39 @@ func CountAllTags() (int64, error) {
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
return total, err
}
// Get channels of specified type with pagination
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
var channels []*Channel
order := "priority desc"
if idSort {
order = "id desc"
}
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
return channels, err
}
// Count channels of specific type
func CountChannelsByType(channelType int) (int64, error) {
var count int64
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
return count, err
}
// Return map[type]count for all channels
func CountChannelsGroupByType() (map[int64]int64, error) {
type result struct {
Type int64 `gorm:"column:type"`
Count int64 `gorm:"column:count"`
}
var results []result
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
if err != nil {
return nil, err
}
counts := make(map[int64]int64)
for _, r := range results {
counts[r.Type] = r.Count
}
return counts, nil
}

View File

@@ -46,6 +46,15 @@ func initCol() {
logGroupCol = commonGroupCol
logKeyCol = commonKeyCol
}
} else {
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
if common.UsingPostgreSQL {
logGroupCol = `"group"`
logKeyCol = `"key"`
} else {
logGroupCol = commonGroupCol
logKeyCol = commonKeyCol
}
}
// log sql type and database type
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)

View File

@@ -14,6 +14,8 @@ type Midjourney struct {
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
ImageUrl string `json:"image_url"`
VideoUrl string `json:"video_url"`
VideoUrls string `json:"video_urls"`
Status string `json:"status" gorm:"type:varchar(20);index"`
Progress string `json:"progress" gorm:"type:varchar(30);index"`
FailReason string `json:"fail_reason"`

View File

@@ -5,6 +5,7 @@ import (
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -76,6 +77,9 @@ func InitOptionMap() {
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -94,13 +98,13 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["GroupGroupRatio"] = setting.GroupGroupRatio2JSONString()
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -123,6 +127,7 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -192,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@@ -261,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
case "DefaultUseAutoGroup":
setting.DefaultUseAutoGroup = boolValue
case "ExposeRatioEnabled":
ratio_setting.SetExposeRatioEnabled(boolValue)
}
}
switch key {
@@ -287,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
case "AutoGroups":
err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
@@ -352,19 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
err = operation_setting.UpdateModelRatioByJSONString(value)
err = ratio_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = setting.UpdateGroupRatioByJSONString(value)
err = ratio_setting.UpdateGroupRatioByJSONString(value)
case "GroupGroupRatio":
err = setting.UpdateGroupGroupRatioByJSONString(value)
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = operation_setting.UpdateCompletionRatioByJSONString(value)
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
err = operation_setting.UpdateModelPriceByJSONString(value)
err = ratio_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = operation_setting.UpdateCacheRatioByJSONString(value)
err = ratio_setting.UpdateCacheRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
@@ -381,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
case "PayMethods":
err = setting.UpdatePayMethodsByJsonString(value)
}
return err
}

View File

@@ -1,20 +1,24 @@
package model
import (
"fmt"
"one-api/common"
"one-api/setting/operation_setting"
"one-api/constant"
"one-api/setting/ratio_setting"
"one-api/types"
"sync"
"time"
)
type Pricing struct {
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups,omitempty"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups"`
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
}
var (
@@ -23,56 +27,98 @@ var (
updatePricingLock sync.Mutex
)
func GetPricing() []Pricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
var (
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
modelSupportEndpointsLock = sync.RWMutex{}
)
func GetPricing() []Pricing {
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing()
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
// Double check after acquiring the lock
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
modelSupportEndpointsLock.Lock()
defer modelSupportEndpointsLock.Unlock()
updatePricing()
}
}
//if group != "" {
// userPricingMap := make([]Pricing, 0)
// models := GetGroupModels(group)
// for _, pricing := range pricingMap {
// if !common.StringsContains(models, pricing.ModelName) {
// pricing.Available = false
// }
// userPricingMap = append(userPricingMap, pricing)
// }
// return userPricingMap
//}
return pricingMap
}
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
if model == "" {
return make([]constant.EndpointType, 0)
}
modelSupportEndpointsLock.RLock()
defer modelSupportEndpointsLock.RUnlock()
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
return endpoints
}
return make([]constant.EndpointType, 0)
}
func updatePricing() {
//modelRatios := common.GetModelRatios()
enableAbilities := GetAllEnableAbilities()
modelGroupsMap := make(map[string][]string)
enableAbilities, err := GetAllEnableAbilityWithChannels()
if err != nil {
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
return
}
modelGroupsMap := make(map[string]*types.Set[string])
for _, ability := range enableAbilities {
groups := modelGroupsMap[ability.Model]
if groups == nil {
groups = make([]string, 0)
groups, ok := modelGroupsMap[ability.Model]
if !ok {
groups = types.NewSet[string]()
modelGroupsMap[ability.Model] = groups
}
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
groups.Add(ability.Group)
}
//这里使用切片而不是Set因为一个模型可能支持多个端点类型并且第一个端点是优先使用端点
modelSupportEndpointsStr := make(map[string][]string)
for _, ability := range enableAbilities {
endpoints, ok := modelSupportEndpointsStr[ability.Model]
if !ok {
endpoints = make([]string, 0)
modelSupportEndpointsStr[ability.Model] = endpoints
}
modelGroupsMap[ability.Model] = groups
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
for _, channelType := range channelTypes {
if !common.StringsContains(endpoints, string(channelType)) {
endpoints = append(endpoints, string(channelType))
}
}
modelSupportEndpointsStr[ability.Model] = endpoints
}
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
for model, endpoints := range modelSupportEndpointsStr {
supportedEndpoints := make([]constant.EndpointType, 0)
for _, endpointStr := range endpoints {
endpointType := constant.EndpointType(endpointStr)
supportedEndpoints = append(supportedEndpoints, endpointType)
}
modelSupportEndpointTypes[model] = supportedEndpoints
}
pricingMap = make([]Pricing, 0)
for model, groups := range modelGroupsMap {
pricing := Pricing{
ModelName: model,
EnableGroup: groups,
ModelName: model,
EnableGroup: groups.Items(),
SupportedEndpointTypes: modelSupportEndpointTypes[model],
}
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
modelRatio, _ := operation_setting.GetModelRatio(model)
modelRatio, _ := ratio_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)

View File

@@ -327,3 +327,37 @@ func CountUserTokens(userId int) (int64, error) {
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
return total, err
}
// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
func BatchDeleteTokens(ids []int, userId int) (int, error) {
if len(ids) == 0 {
return 0, errors.New("ids 不能为空!")
}
tx := DB.Begin()
var tokens []Token
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
tx.Rollback()
return 0, err
}
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
tx.Rollback()
return 0, err
}
if err := tx.Commit().Error; err != nil {
return 0, err
}
if common.RedisEnabled {
gopool.Go(func() {
for _, t := range tokens {
_ = cacheDeleteToken(t.Key)
}
})
}
return len(tokens), nil
}

View File

@@ -10,7 +10,7 @@ import (
func cacheSetToken(token Token) error {
key := common.GenerateHMAC(token.Key)
token.Clean()
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
if err != nil {
return err
}

View File

@@ -41,6 +41,7 @@ type User struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
}
func (user *User) ToBaseUser() *UserBase {
@@ -113,7 +114,7 @@ func GetMaxUserId() int {
return user.Id
}
func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) {
func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
// Start transaction
tx := DB.Begin()
if tx.Error != nil {
@@ -133,7 +134,7 @@ func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error)
}
// Get paginated users within same transaction
err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
if err != nil {
tx.Rollback()
return nil, 0, err
@@ -366,6 +367,7 @@ func (user *User) Edit(updatePassword bool) error {
"display_name": newUser.DisplayName,
"group": newUser.Group,
"quota": newUser.Quota,
"remark": newUser.Remark,
}
if updatePassword {
updates["password"] = newUser.Password

View File

@@ -24,12 +24,12 @@ type UserBase struct {
}
func (user *UserBase) WriteContext(c *gin.Context) {
c.Set(constant.ContextKeyUserGroup, user.Group)
c.Set(constant.ContextKeyUserQuota, user.Quota)
c.Set(constant.ContextKeyUserStatus, user.Status)
c.Set(constant.ContextKeyUserEmail, user.Email)
c.Set("username", user.Username)
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
time.Duration(common.RedisKeyCacheSeconds())*time.Second,
)
}

View File

@@ -2,11 +2,12 @@ package model
import (
"errors"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
const (
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
// check if there's any data to update
hasData := false
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
if len(batchUpdateStores[i]) > 0 {
hasData = true
batchUpdateLocks[i].Unlock()
break
}
batchUpdateLocks[i].Unlock()
}
if !hasData {
return
}
common.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()