diff --git a/controller/channel.go b/controller/channel.go index 4c45574f..c2430a04 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -350,6 +350,46 @@ func GetChannel(c *gin.Context) { return } +// validateChannel 通用的渠道校验函数 +func validateChannel(channel *model.Channel, isAdd bool) error { + // 校验 channel settings + if err := channel.ValidateSettings(); err != nil { + return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error()) + } + + // 如果是添加操作,检查 channel 和 key 是否为空 + if isAdd { + if channel == nil || channel.Key == "" { + return fmt.Errorf("channel cannot be empty") + } + + // 检查模型名称长度是否超过 255 + for _, m := range channel.GetModels() { + if len(m) > 255 { + return fmt.Errorf("模型名称过长: %s", m) + } + } + } + + // VertexAI 特殊校验 + if channel.Type == constant.ChannelTypeVertexAi { + if channel.Other == "" { + return fmt.Errorf("部署地区不能为空") + } + + regionMap, err := common.StrToMap(channel.Other) + if err != nil { + return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}") + } + + if regionMap["default"] == nil { + return fmt.Errorf("部署地区必须包含default字段") + } + } + + return nil +} + type AddChannelRequest struct { Mode string `json:"mode"` MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` @@ -396,59 +436,15 @@ func AddChannel(c *gin.Context) { return } - err = addChannelRequest.Channel.ValidateSettings() - if err != nil { + // 使用统一的校验函数 + if err := validateChannel(addChannelRequest.Channel, true); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "channel setting 格式错误:" + err.Error(), + "message": err.Error(), }) return } - if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "channel cannot be empty", - }) - return - } - - // Validate the length of the model name - for _, m := range addChannelRequest.Channel.GetModels() { - if len(m) > 255 { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": fmt.Sprintf("模型名称过长: %s", m), - }) - return - } - } - if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { - if addChannelRequest.Channel.Other == "" { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区不能为空", - }) - return - } else { - regionMap, err := common.StrToMap(addChannelRequest.Channel.Other) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}", - }) - return - } - if regionMap["default"] == nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区必须包含default字段", - }) - return - } - } - } - addChannelRequest.Channel.CreatedTime = common.GetTimestamp() keys := make([]string, 0) switch addChannelRequest.Mode { @@ -676,39 +672,15 @@ func UpdateChannel(c *gin.Context) { common.ApiError(c, err) return } - err = channel.ValidateSettings() - if err != nil { + + // 使用统一的校验函数 + if err := validateChannel(&channel.Channel, false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "channel setting 格式错误:" + err.Error(), + "message": err.Error(), }) return } - if channel.Type == constant.ChannelTypeVertexAi { - if channel.Other == "" { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区不能为空", - }) - return - } else { - regionMap, err := common.StrToMap(channel.Other) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}", - }) - return - } - if regionMap["default"] == nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "部署地区必须包含default字段", - }) - return - } - } - } // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request. originChannel, err := model.GetChannelById(channel.Id, false) if err != nil { @@ -731,6 +703,9 @@ func UpdateChannel(c *gin.Context) { common.ApiError(c, err) return } + if common.MemoryCacheEnabled { + model.InitChannelCache() + } channel.Key = "" c.JSON(http.StatusOK, gin.H{ "success": true, @@ -888,8 +863,9 @@ func GetTagModels(c *gin.Context) { // CopyChannel handles cloning an existing channel with its key. // POST /api/channel/copy/:id // Optional query params: -// suffix - string appended to the original name (default "_复制") -// reset_balance - bool, when true will reset balance & used_quota to 0 (default true) +// +// suffix - string appended to the original name (default "_复制") +// reset_balance - bool, when true will reset balance & used_quota to 0 (default true) func CopyChannel(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { @@ -914,7 +890,7 @@ func CopyChannel(c *gin.Context) { // clone channel clone := *origin // shallow copy is sufficient as we will overwrite primitives - clone.Id = 0 // let DB auto-generate + clone.Id = 0 // let DB auto-generate clone.CreatedTime = common.GetTimestamp() clone.Name = origin.Name + suffix clone.TestTime = 0