diff --git a/common/model-ratio.go b/common/model-ratio.go index b45ed293..3713c2c4 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -375,6 +375,9 @@ func GetCompletionRatio(name string) float64 { return 3 } if strings.HasPrefix(name, "gemini-") { + if strings.Contains(name, "flash") { + return 4 + } return 3 } if strings.HasPrefix(name, "command") { diff --git a/controller/model.go b/controller/model.go index 36beb2d1..3d207023 100644 --- a/controller/model.go +++ b/controller/model.go @@ -137,15 +137,6 @@ func init() { } func ListModels(c *gin.Context) { - userId := c.GetInt("id") - user, err := model.GetUserById(userId, true) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } userOpenAiModels := make([]dto.OpenAIModels, 0) permission := getPermission() @@ -174,7 +165,21 @@ func ListModels(c *gin.Context) { } } } else { - models := model.GetGroupModels(user.Group) + userId := c.GetInt("id") + userGroup, err := model.GetUserGroup(userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "get user group failed", + }) + return + } + group := userGroup + tokenGroup := c.GetString("token_group") + if tokenGroup != "" { + group = tokenGroup + } + models := model.GetGroupModels(group) for _, s := range models { if _, ok := openAIModelsMap[s]; ok { userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s]) diff --git a/middleware/distributor.go b/middleware/distributor.go index 0393d24f..7a959649 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -41,9 +41,14 @@ func Distribute() func(c *gin.Context) { userGroup, _ := model.CacheGetUserGroup(userId) tokenGroup := c.GetString("token_group") if tokenGroup != "" { + // check common.UserUsableGroups[userGroup] + if _, ok := common.UserUsableGroups[tokenGroup]; !ok { + abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup)) + return + } // check group in common.GroupRatio if _, ok := common.GroupRatio[tokenGroup]; !ok { - abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被禁用", tokenGroup)) + abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) return } userGroup = tokenGroup