From 29ec328f464ff66a3a911aa6e1a9aa83029457ec Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 8 Aug 2025 11:59:04 +0800 Subject: [PATCH] fix: playground group --- controller/playground.go | 30 +++--------------------------- middleware/distributor.go | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 27 deletions(-) diff --git a/controller/playground.go b/controller/playground.go index 64c0e1ce..dd930802 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -5,10 +5,8 @@ import ( "fmt" "one-api/common" "one-api/constant" - "one-api/dto" "one-api/middleware" "one-api/model" - "one-api/setting" "one-api/types" "time" @@ -32,30 +30,8 @@ func Playground(c *gin.Context) { return } - playgroundRequest := &dto.PlayGroundRequest{} - err := common.UnmarshalBodyReusable(c, playgroundRequest) - if err != nil { - newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - return - } - - if playgroundRequest.Model == "" { - newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - return - } - c.Set("original_model", playgroundRequest.Model) - group := playgroundRequest.Group - userGroup := c.GetString("group") - - if group == "" { - group = userGroup - } else { - if !setting.GroupInUserUsableGroups(group) && group != userGroup { - newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry()) - return - } - c.Set("group", group) - } + group := c.GetString("group") + modelName := c.GetString("original_model") userId := c.GetInt("id") @@ -73,7 +49,7 @@ func Playground(c *gin.Context) { Group: group, } _ = middleware.SetupContextForToken(c, tempToken) - _, newAPIError = getChannel(c, group, playgroundRequest.Model, 0) + _, newAPIError = getChannel(c, group, modelName, 0) if newAPIError != nil { return } diff --git a/middleware/distributor.go b/middleware/distributor.go index 5fae6322..e8abcbe9 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -10,6 +10,7 @@ import ( "one-api/model" relayconstant "one-api/relay/constant" "one-api/service" + "one-api/setting" "one-api/setting/ratio_setting" "one-api/types" "strconv" @@ -78,6 +79,22 @@ func Distribute() func(c *gin.Context) { } var selectGroup string userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) + // check path is /pg/chat/completions + if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") { + playgroundRequest := &dto.PlayGroundRequest{} + err = common.UnmarshalBodyReusable(c, playgroundRequest) + if err != nil { + abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) + return + } + if playgroundRequest.Group != "" { + if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup { + abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组") + return + } + userGroup = playgroundRequest.Group + } + } channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0) if err != nil { showGroup := userGroup