fix: playground group
This commit is contained in:
@@ -5,10 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -32,30 +30,8 @@ func Playground(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
playgroundRequest := &dto.PlayGroundRequest{}
|
group := c.GetString("group")
|
||||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
modelName := c.GetString("original_model")
|
||||||
if err != nil {
|
|
||||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if playgroundRequest.Model == "" {
|
|
||||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("original_model", playgroundRequest.Model)
|
|
||||||
group := playgroundRequest.Group
|
|
||||||
userGroup := c.GetString("group")
|
|
||||||
|
|
||||||
if group == "" {
|
|
||||||
group = userGroup
|
|
||||||
} else {
|
|
||||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
|
||||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("group", group)
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
@@ -73,7 +49,7 @@ func Playground(c *gin.Context) {
|
|||||||
Group: group,
|
Group: group,
|
||||||
}
|
}
|
||||||
_ = middleware.SetupContextForToken(c, tempToken)
|
_ = middleware.SetupContextForToken(c, tempToken)
|
||||||
_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
|
_, newAPIError = getChannel(c, group, modelName, 0)
|
||||||
if newAPIError != nil {
|
if newAPIError != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -78,6 +79,22 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var selectGroup string
|
var selectGroup string
|
||||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||||
|
// check path is /pg/chat/completions
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||||
|
playgroundRequest := &dto.PlayGroundRequest{}
|
||||||
|
err = common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||||
|
if err != nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if playgroundRequest.Group != "" {
|
||||||
|
if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userGroup = playgroundRequest.Group
|
||||||
|
}
|
||||||
|
}
|
||||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
showGroup := userGroup
|
showGroup := userGroup
|
||||||
|
|||||||
Reference in New Issue
Block a user