fix: playground group

This commit is contained in:
CaIon
2025-08-08 11:59:04 +08:00
parent b843bb8286
commit 29ec328f46
2 changed files with 20 additions and 27 deletions

View File

@@ -5,10 +5,8 @@ import (
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/setting"
"one-api/types" "one-api/types"
"time" "time"
@@ -32,30 +30,8 @@ func Playground(c *gin.Context) {
return return
} }
playgroundRequest := &dto.PlayGroundRequest{} group := c.GetString("group")
err := common.UnmarshalBodyReusable(c, playgroundRequest) modelName := c.GetString("original_model")
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
return
}
if playgroundRequest.Model == "" {
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
userGroup := c.GetString("group")
if group == "" {
group = userGroup
} else {
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
return
}
c.Set("group", group)
}
userId := c.GetInt("id") userId := c.GetInt("id")
@@ -73,7 +49,7 @@ func Playground(c *gin.Context) {
Group: group, Group: group,
} }
_ = middleware.SetupContextForToken(c, tempToken) _ = middleware.SetupContextForToken(c, tempToken)
_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0) _, newAPIError = getChannel(c, group, modelName, 0)
if newAPIError != nil { if newAPIError != nil {
return return
} }

View File

@@ -10,6 +10,7 @@ import (
"one-api/model" "one-api/model"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"one-api/setting"
"one-api/setting/ratio_setting" "one-api/setting/ratio_setting"
"one-api/types" "one-api/types"
"strconv" "strconv"
@@ -78,6 +79,22 @@ func Distribute() func(c *gin.Context) {
} }
var selectGroup string var selectGroup string
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
// check path is /pg/chat/completions
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
playgroundRequest := &dto.PlayGroundRequest{}
err = common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return
}
if playgroundRequest.Group != "" {
if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup {
abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组")
return
}
userGroup = playgroundRequest.Group
}
}
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0) channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil { if err != nil {
showGroup := userGroup showGroup := userGroup