From f2809917f856bef6a68bafa3f607768a2df8bf67 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sun, 15 Dec 2024 15:52:41 +0800 Subject: [PATCH] feat: implement channel settings configuration fix #620 --- constant/channel_setting.go | 5 +++ middleware/distributor.go | 1 + model/channel.go | 21 +++++++++++ relay/channel/openai/relay-openai.go | 56 +++++++++++++++++++++++++--- relay/common/relay_info.go | 8 +++- web/src/pages/Channel/EditChannel.js | 36 ++++++++++++++++++ 6 files changed, 120 insertions(+), 7 deletions(-) create mode 100644 constant/channel_setting.go diff --git a/constant/channel_setting.go b/constant/channel_setting.go new file mode 100644 index 00000000..71b9f58b --- /dev/null +++ b/constant/channel_setting.go @@ -0,0 +1,5 @@ +package constant + +var ( + ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式 +) diff --git a/middleware/distributor.go b/middleware/distributor.go index 9bab3e99..022dd0af 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -213,6 +213,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type) + c.Set("channel_setting", channel.GetSetting()) if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { c.Set("channel_organization", *channel.OpenAIOrganization) } diff --git a/model/channel.go b/model/channel.go index 09094611..af7c4279 100644 --- a/model/channel.go +++ b/model/channel.go @@ -34,6 +34,7 @@ type Channel struct { AutoBan *int `json:"auto_ban" gorm:"default:1"` OtherInfo string `json:"other_info"` Tag *string `json:"tag" gorm:"index"` + Setting string `json:"setting" gorm:"type:text"` } func (channel *Channel) GetModels() []string { @@ -469,3 +470,23 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str return tags, nil } + +func (channel *Channel) GetSetting() map[string]interface{} { + setting := make(map[string]interface{}) + if channel.Setting != "" { + err := json.Unmarshal([]byte(channel.Setting), &setting) + if err != nil { + common.SysError("failed to unmarshal setting: " + err.Error()) + } + } + return setting +} + +func (channel *Channel) SetSetting(setting map[string]interface{}) { + settingBytes, err := json.Marshal(setting) + if err != nil { + common.SysError("failed to marshal setting: " + err.Error()) + return + } + channel.Setting = string(settingBytes) +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index bac0578c..bd39b904 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -5,9 +5,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "io" "net/http" "one-api/common" @@ -19,9 +16,33 @@ import ( "strings" "sync" "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) +func sendStreamData(c *gin.Context, data string, forceFormat bool) error { + if data == "" { + return nil + } + + if forceFormat { + var lastStreamResponse dto.ChatCompletionsStreamResponse + if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil { + return err + } + return service.ObjectData(c, lastStreamResponse) + } + return service.StringData(c, data) +} + func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + if resp == nil || resp.Body == nil { + common.LogError(c, "invalid response or response body") + return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil + } + containStreamUsage := false var responseId string var createAt int64 = 0 @@ -31,6 +52,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel var responseTextBuilder strings.Builder var usage = &dto.Usage{} var streamItems []string // store stream items + var forceFormat bool + + if info.ChannelType == common.ChannelTypeCustom { + if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok { + forceFormat = forceFmt + } + } toolCount := 0 scanner := bufio.NewScanner(resp.Body) @@ -62,7 +90,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel data = data[6:] if !strings.HasPrefix(data, "[DONE]") { if lastStreamData != "" { - err := service.StringData(c, lastStreamData) + err := sendStreamData(c, lastStreamData, forceFormat) if err != nil { common.LogError(c, "streaming error: "+err.Error()) } @@ -105,7 +133,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } } if shouldSendLastResp { - service.StringData(c, lastStreamData) + sendStreamData(c, lastStreamData, forceFormat) } // 计算token @@ -375,6 +403,10 @@ func getTextFromJSON(body []byte) (string, error) { } func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) { + if info == nil || info.ClientWs == nil || info.TargetWs == nil { + return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil + } + info.IsStream = true clientConn := info.ClientWs targetConn := info.TargetWs @@ -390,6 +422,11 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op sumUsage := &dto.RealtimeUsage{} gopool.Go(func() { + defer func() { + if r := recover(); r != nil { + errChan <- fmt.Errorf("panic in client reader: %v", r) + } + }() for { select { case <-c.Done(): @@ -445,6 +482,11 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op }) gopool.Go(func() { + defer func() { + if r := recover(); r != nil { + errChan <- fmt.Errorf("panic in target reader: %v", r) + } + }() for { select { case <-c.Done(): @@ -568,6 +610,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error { + if usage == nil || totalUsage == nil { + return fmt.Errorf("invalid usage pointer") + } + totalUsage.TotalTokens += usage.TotalTokens totalUsage.InputTokens += usage.InputTokens totalUsage.OutputTokens += usage.OutputTokens diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index ae95cd51..3bfc2ef6 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -1,13 +1,14 @@ package common import ( - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "one-api/common" "one-api/dto" "one-api/relay/constant" "strings" "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) type RelayInfo struct { @@ -43,6 +44,7 @@ type RelayInfo struct { RealtimeTools []dto.RealTimeTool IsFirstRequest bool AudioUsage bool + ChannelSetting map[string]interface{} } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { @@ -57,6 +59,7 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") + channelSetting := c.GetStringMap("channel_setting") tokenId := c.GetInt("token_id") tokenKey := c.GetString("token_key") @@ -87,6 +90,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { ApiVersion: c.GetString("api_version"), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Organization: c.GetString("channel_organization"), + ChannelSetting: channelSetting, } if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index eab82a8e..3e387a26 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -964,6 +964,42 @@ const EditChannel = (props) => { value={inputs.weight} autoComplete="new-password" /> + {inputs.type === 8 && ( + <> +