refactor: Playground controller
This commit is contained in:
66
controller/playground.go
Normal file
66
controller/playground.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/middleware"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Playground(c *gin.Context) {
|
||||||
|
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if openaiErr != nil {
|
||||||
|
c.JSON(openaiErr.StatusCode, gin.H{
|
||||||
|
"error": openaiErr.Error,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
useAccessToken := c.GetBool("use_access_token")
|
||||||
|
if useAccessToken {
|
||||||
|
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
playgroundRequest := &dto.PlayGroundRequest{}
|
||||||
|
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||||
|
if err != nil {
|
||||||
|
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if playgroundRequest.Model == "" {
|
||||||
|
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("original_model", playgroundRequest.Model)
|
||||||
|
group := playgroundRequest.Group
|
||||||
|
userGroup := c.GetString("group")
|
||||||
|
|
||||||
|
if group == "" {
|
||||||
|
group = userGroup
|
||||||
|
} else {
|
||||||
|
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||||
|
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("group", group)
|
||||||
|
}
|
||||||
|
c.Set("token_name", "playground-"+group)
|
||||||
|
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
||||||
|
if err != nil {
|
||||||
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
||||||
|
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||||
|
Relay(c)
|
||||||
|
}
|
||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,58 +48,6 @@ func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErr
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Playground(c *gin.Context) {
|
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if openaiErr != nil {
|
|
||||||
c.JSON(openaiErr.StatusCode, gin.H{
|
|
||||||
"error": openaiErr.Error,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
useAccessToken := c.GetBool("use_access_token")
|
|
||||||
if useAccessToken {
|
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
playgroundRequest := &dto.PlayGroundRequest{}
|
|
||||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
|
||||||
if err != nil {
|
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if playgroundRequest.Model == "" {
|
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("original_model", playgroundRequest.Model)
|
|
||||||
group := playgroundRequest.Group
|
|
||||||
userGroup := c.GetString("group")
|
|
||||||
|
|
||||||
if group == "" {
|
|
||||||
group = userGroup
|
|
||||||
} else {
|
|
||||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("group", group)
|
|
||||||
}
|
|
||||||
c.Set("token_name", "playground-"+group)
|
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
|
||||||
if err != nil {
|
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
|
||||||
Relay(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
|
|||||||
Reference in New Issue
Block a user