✨ feat: refactor environment variable initialization and introduce new constant types for API and context keys
This commit is contained in:
@@ -21,7 +21,22 @@ type Ability struct {
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
}
|
||||
|
||||
func GetGroupModels(group string) []string {
|
||||
type AbilityWithChannel struct {
|
||||
Ability
|
||||
ChannelType int `json:"channel_type"`
|
||||
}
|
||||
|
||||
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
|
||||
var abilities []AbilityWithChannel
|
||||
err := DB.Table("abilities").
|
||||
Select("abilities.*, channels.type as channel_type").
|
||||
Joins("left join channels on abilities.channel_id = channels.id").
|
||||
Where("abilities.enabled = ?", true).
|
||||
Scan(&abilities).Error
|
||||
return abilities, err
|
||||
}
|
||||
|
||||
func GetGroupEnabledModels(group string) []string {
|
||||
var models []string
|
||||
// Find distinct models
|
||||
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
@@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
var priorities []int
|
||||
err := DB.Model(&Ability{}).
|
||||
Select("DISTINCT(priority)").
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||
|
||||
@@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
} else {
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
110
model/pricing.go
110
model/pricing.go
@@ -1,20 +1,24 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/types"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Pricing struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
EnableGroup []string `json:"enable_groups,omitempty"`
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
EnableGroup []string `json:"enable_groups"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -23,47 +27,89 @@ var (
|
||||
updatePricingLock sync.Mutex
|
||||
)
|
||||
|
||||
func GetPricing() []Pricing {
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
var (
|
||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||
modelSupportEndpointsLock = sync.RWMutex{}
|
||||
)
|
||||
|
||||
func GetPricing() []Pricing {
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
updatePricing()
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
// Double check after acquiring the lock
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
modelSupportEndpointsLock.Lock()
|
||||
defer modelSupportEndpointsLock.Unlock()
|
||||
updatePricing()
|
||||
}
|
||||
}
|
||||
//if group != "" {
|
||||
// userPricingMap := make([]Pricing, 0)
|
||||
// models := GetGroupModels(group)
|
||||
// for _, pricing := range pricingMap {
|
||||
// if !common.StringsContains(models, pricing.ModelName) {
|
||||
// pricing.Available = false
|
||||
// }
|
||||
// userPricingMap = append(userPricingMap, pricing)
|
||||
// }
|
||||
// return userPricingMap
|
||||
//}
|
||||
return pricingMap
|
||||
}
|
||||
|
||||
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||||
if model == "" {
|
||||
return make([]constant.EndpointType, 0)
|
||||
}
|
||||
modelSupportEndpointsLock.RLock()
|
||||
defer modelSupportEndpointsLock.RUnlock()
|
||||
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
|
||||
return endpoints
|
||||
}
|
||||
return make([]constant.EndpointType, 0)
|
||||
}
|
||||
|
||||
func updatePricing() {
|
||||
//modelRatios := common.GetModelRatios()
|
||||
enableAbilities := GetAllEnableAbilities()
|
||||
modelGroupsMap := make(map[string][]string)
|
||||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||
return
|
||||
}
|
||||
modelGroupsMap := make(map[string]*types.Set[string])
|
||||
|
||||
for _, ability := range enableAbilities {
|
||||
groups := modelGroupsMap[ability.Model]
|
||||
if groups == nil {
|
||||
groups = make([]string, 0)
|
||||
groups, ok := modelGroupsMap[ability.Model]
|
||||
if !ok {
|
||||
groups = types.NewSet[string]()
|
||||
modelGroupsMap[ability.Model] = groups
|
||||
}
|
||||
if !common.StringsContains(groups, ability.Group) {
|
||||
groups = append(groups, ability.Group)
|
||||
groups.Add(ability.Group)
|
||||
}
|
||||
|
||||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||||
modelSupportEndpointsStr := make(map[string][]string)
|
||||
|
||||
for _, ability := range enableAbilities {
|
||||
endpoints, ok := modelSupportEndpointsStr[ability.Model]
|
||||
if !ok {
|
||||
endpoints = make([]string, 0)
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
modelGroupsMap[ability.Model] = groups
|
||||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||||
for _, channelType := range channelTypes {
|
||||
if !common.StringsContains(endpoints, string(channelType)) {
|
||||
endpoints = append(endpoints, string(channelType))
|
||||
}
|
||||
}
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
|
||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||
for model, endpoints := range modelSupportEndpointsStr {
|
||||
supportedEndpoints := make([]constant.EndpointType, 0)
|
||||
for _, endpointStr := range endpoints {
|
||||
endpointType := constant.EndpointType(endpointStr)
|
||||
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||||
}
|
||||
modelSupportEndpointTypes[model] = supportedEndpoints
|
||||
}
|
||||
|
||||
pricingMap = make([]Pricing, 0)
|
||||
for model, groups := range modelGroupsMap {
|
||||
pricing := Pricing{
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
ModelName: model,
|
||||
EnableGroup: groups.Items(),
|
||||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||||
}
|
||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
func cacheSetToken(token Token) error {
|
||||
key := common.GenerateHMAC(token.Key)
|
||||
token.Clean()
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -24,12 +24,12 @@ type UserBase struct {
|
||||
}
|
||||
|
||||
func (user *UserBase) WriteContext(c *gin.Context) {
|
||||
c.Set(constant.ContextKeyUserGroup, user.Group)
|
||||
c.Set(constant.ContextKeyUserQuota, user.Quota)
|
||||
c.Set(constant.ContextKeyUserStatus, user.Status)
|
||||
c.Set(constant.ContextKeyUserEmail, user.Email)
|
||||
c.Set("username", user.Username)
|
||||
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
|
||||
common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
|
||||
common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
|
||||
common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
|
||||
common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
|
||||
common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
|
||||
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
|
||||
}
|
||||
|
||||
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
|
||||
return common.RedisHSetObj(
|
||||
getUserCacheKey(user.Id),
|
||||
user.ToBaseUser(),
|
||||
time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
|
||||
time.Duration(common.RedisKeyCacheSeconds())*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user