diff --git a/controller/group.go b/controller/group.go index 11a0a3d3..6ba339a3 100644 --- a/controller/group.go +++ b/controller/group.go @@ -29,11 +29,11 @@ func GetUserGroups(c *gin.Context) { userId := c.GetInt("id") userGroup, _ = model.GetUserGroup(userId, false) userUsableGroups := service.GetUserUsableGroups(userGroup) - for groupName, ratio := range ratio_setting.GetGroupRatioCopy() { + for groupName, _ := range ratio_setting.GetGroupRatioCopy() { // UserUsableGroups contains the groups that the user can use if desc, ok := userUsableGroups[groupName]; ok { usableGroups[groupName] = map[string]interface{}{ - "ratio": ratio, + "ratio": service.GetUserGroupRatio(userGroup, groupName), "desc": desc, } } diff --git a/controller/playground.go b/controller/playground.go index f6e0953f..342f47cf 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -31,7 +31,7 @@ func Playground(c *gin.Context) { return } - group := c.GetString("group") + group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) modelName := c.GetString("original_model") userId := c.GetInt("id") diff --git a/middleware/distributor.go b/middleware/distributor.go index fabcac20..5a9deb23 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -94,6 +94,7 @@ func Distribute() func(c *gin.Context) { return } usingGroup = playgroundRequest.Group + common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup) } } channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(c, usingGroup, modelRequest.Model, 0) diff --git a/service/group.go b/service/group.go index 0d8a7037..a73642c3 100644 --- a/service/group.go +++ b/service/group.go @@ -52,3 +52,14 @@ func GetUserAutoGroup(userGroup string) []string { } return autoGroups } + +// GetUserGroupRatio 获取用户使用某个分组的倍率 +// userGroup 用户分组 +// group 需要获取倍率的分组 +func GetUserGroupRatio(userGroup, group string) float64 { + ratio, ok := ratio_setting.GetGroupGroupRatio(userGroup, group) + if ok { + return ratio + } + return ratio_setting.GetGroupRatio(group) +}