From c8614f98906304038a4a2b4276b3293808a26f8f Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 28 Dec 2024 16:47:56 +0800 Subject: [PATCH] refactor: Playground controller --- controller/playground.go | 66 ++++++++++++++++++++++++++++++++++++++++ controller/relay.go | 53 -------------------------------- 2 files changed, 66 insertions(+), 53 deletions(-) create mode 100644 controller/playground.go diff --git a/controller/playground.go b/controller/playground.go new file mode 100644 index 00000000..2c81a1b6 --- /dev/null +++ b/controller/playground.go @@ -0,0 +1,66 @@ +package controller + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/dto" + "one-api/middleware" + "one-api/model" + "one-api/service" + "one-api/setting" +) + +func Playground(c *gin.Context) { + var openaiErr *dto.OpenAIErrorWithStatusCode + + defer func() { + if openaiErr != nil { + c.JSON(openaiErr.StatusCode, gin.H{ + "error": openaiErr.Error, + }) + } + }() + + useAccessToken := c.GetBool("use_access_token") + if useAccessToken { + openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest) + return + } + + playgroundRequest := &dto.PlayGroundRequest{} + err := common.UnmarshalBodyReusable(c, playgroundRequest) + if err != nil { + openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest) + return + } + + if playgroundRequest.Model == "" { + openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest) + return + } + c.Set("original_model", playgroundRequest.Model) + group := playgroundRequest.Group + userGroup := c.GetString("group") + + if group == "" { + group = userGroup + } else { + if !setting.GroupInUserUsableGroups(group) && group != userGroup { + openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden) + return + } + c.Set("group", group) + } + c.Set("token_name", "playground-"+group) + channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0) + if err != nil { + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model) + openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError) + return + } + middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) + Relay(c) +} diff --git a/controller/relay.go b/controller/relay.go index 5581d5bb..72d421e3 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -17,7 +17,6 @@ import ( "one-api/relay/constant" relayconstant "one-api/relay/constant" "one-api/service" - "one-api/setting" "strings" ) @@ -49,58 +48,6 @@ func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErr return err } -func Playground(c *gin.Context) { - var openaiErr *dto.OpenAIErrorWithStatusCode - - defer func() { - if openaiErr != nil { - c.JSON(openaiErr.StatusCode, gin.H{ - "error": openaiErr.Error, - }) - } - }() - - useAccessToken := c.GetBool("use_access_token") - if useAccessToken { - openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest) - return - } - - playgroundRequest := &dto.PlayGroundRequest{} - err := common.UnmarshalBodyReusable(c, playgroundRequest) - if err != nil { - openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest) - return - } - - if playgroundRequest.Model == "" { - openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest) - return - } - c.Set("original_model", playgroundRequest.Model) - group := playgroundRequest.Group - userGroup := c.GetString("group") - - if group == "" { - group = userGroup - } else { - if !setting.GroupInUserUsableGroups(group) && group != userGroup { - openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden) - return - } - c.Set("group", group) - } - c.Set("token_name", "playground-"+group) - channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model) - openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError) - return - } - middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) - Relay(c) -} - func Relay(c *gin.Context) { relayMode := constant.Path2RelayMode(c.Request.URL.Path) requestId := c.GetString(common.RequestIdKey)