From 0089157b83633a0892e5fe2b0960e7791f213d31 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 16 Jun 2025 00:37:22 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(channel):=20enhance=20AddChann?= =?UTF-8?q?el=20functionality=20with=20structured=20request=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel.go | 107 +++++++++++++++++++++++++++++++++--------- model/channel.go | 9 +++- model/main.go | 2 +- 3 files changed, 94 insertions(+), 24 deletions(-) diff --git a/controller/channel.go b/controller/channel.go index 1cfb7906..f2b9ad7e 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -250,9 +250,14 @@ func GetChannel(c *gin.Context) { return } +type AddChannelRequest struct { + Mode string `json:"mode"` + Channel *model.Channel `json:"channel"` +} + func AddChannel(c *gin.Context) { - channel := model.Channel{} - err := c.ShouldBindJSON(&channel) + addChannelRequest := AddChannelRequest{} + err := c.ShouldBindJSON(&addChannelRequest) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -260,19 +265,35 @@ func AddChannel(c *gin.Context) { }) return } - channel.CreatedTime = common.GetTimestamp() - keys := strings.Split(channel.Key, "\n") - if channel.Type == common.ChannelTypeVertexAi { - if channel.Other == "" { + if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "channel cannot be empty", + }) + return + } + + // Validate the length of the model name + for _, m := range addChannelRequest.Channel.GetModels() { + if len(m) > 255 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("模型名称过长: %s", m), + }) + return + } + } + if addChannelRequest.Channel.Type == common.ChannelTypeVertexAi { + if addChannelRequest.Channel.Other == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "部署地区不能为空", }) return } else { - if common.IsJsonStr(channel.Other) { + if common.IsJsonStr(addChannelRequest.Channel.Other) { // must have default - regionMap := common.StrToMap(channel.Other) + regionMap := common.StrToMap(addChannelRequest.Channel.Other) if regionMap["default"] == nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -282,27 +303,69 @@ func AddChannel(c *gin.Context) { } } } - keys = []string{channel.Key} } + + addChannelRequest.Channel.CreatedTime = common.GetTimestamp() + keys := make([]string, 0) + switch addChannelRequest.Mode { + case "multi_to_single": + addChannelRequest.Channel.ChannelInfo.MultiKeyMode = true + if addChannelRequest.Channel.Type == common.ChannelTypeVertexAi { + if !common.IsJsonStr(addChannelRequest.Channel.Key) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入", + }) + return + } + } + keys = []string{addChannelRequest.Channel.Key} + case "batch": + if addChannelRequest.Channel.Type == common.ChannelTypeVertexAi { + // multi json + if !common.IsJsonStr(addChannelRequest.Channel.Key) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入", + }) + return + } + toMap := common.StrToMap(addChannelRequest.Channel.Key) + if toMap == nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Vertex AI 批量添加模式必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入", + }) + return + } + keys = make([]string, 0, len(toMap)) + for k := range toMap { + if k == "" { + continue + } + keys = append(keys, k) + } + } else { + keys = strings.Split(addChannelRequest.Channel.Key, "\n") + } + case "single": + keys = []string{addChannelRequest.Channel.Key} + default: + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "不支持的添加模式", + }) + return + } + channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { continue } - localChannel := channel + localChannel := addChannelRequest.Channel localChannel.Key = key - // Validate the length of the model name - models := strings.Split(localChannel.Models, ",") - for _, model := range models { - if len(model) > 255 { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": fmt.Sprintf("模型名称过长: %s", model), - }) - return - } - } - channels = append(channels, localChannel) + channels = append(channels, *localChannel) } err = model.BatchInsertChannels(channels) if err != nil { diff --git a/model/channel.go b/model/channel.go index b5503eee..755bd0b2 100644 --- a/model/channel.go +++ b/model/channel.go @@ -9,6 +9,11 @@ import ( "gorm.io/gorm" ) +type ChannelInfo struct { + MultiKeyMode bool `json:"multi_key_mode"` // 是否多Key模式 + MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status +} + type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` @@ -35,8 +40,10 @@ type Channel struct { AutoBan *int `json:"auto_ban" gorm:"default:1"` OtherInfo string `json:"other_info"` Tag *string `json:"tag" gorm:"index"` - Setting *string `json:"setting" gorm:"type:text"` + Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 ParamOverride *string `json:"param_override" gorm:"type:text"` + // add after v0.8.5 + ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` } func (channel *Channel) GetModels() []string { diff --git a/model/main.go b/model/main.go index 965bba93..b7a5af5d 100644 --- a/model/main.go +++ b/model/main.go @@ -48,7 +48,7 @@ func initCol() { } } // log sql type and database type - common.SysLog("Using Log SQL Type: " + common.LogSqlType) + //common.SysLog("Using Log SQL Type: " + common.LogSqlType) } var DB *gorm.DB