diff --git a/controller/channel.go b/controller/channel.go
index 78525950..dae961d5 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -57,10 +57,37 @@ func GetAllChannels(c *gin.Context) {
})
return
}
+ tags := make(map[string]bool)
+ channelData := make([]*model.Channel, 0, len(channels))
+ tagChannels := make([]*model.Channel, 0)
+ for _, channel := range channels {
+ channelTag := channel.GetTag()
+ if channelTag != "" && !tags[channelTag] {
+ tags[channelTag] = true
+ tagChannel, err := model.GetChannelsByTag(channelTag)
+ if err == nil {
+ tagChannels = append(tagChannels, tagChannel...)
+ }
+ } else {
+ channelData = append(channelData, channel)
+ }
+ }
+ for i, channel := range tagChannels {
+ find := false
+ for _, can := range channelData {
+ if channel.Id == can.Id {
+ find = true
+ break
+ }
+ }
+ if !find {
+ channelData = append(channelData, tagChannels[i])
+ }
+ }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": channels,
+ "data": channelData,
})
return
}
@@ -144,8 +171,8 @@ func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
- //idSort, _ := strconv.ParseBool(c.Query("id_sort"))
- channels, err := model.SearchChannels(keyword, group, modelKeyword)
+ idSort, _ := strconv.ParseBool(c.Query("id_sort"))
+ channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -279,6 +306,98 @@ func DeleteDisabledChannel(c *gin.Context) {
return
}
+type ChannelTag struct {
+ Tag string `json:"tag"`
+ NewTag *string `json:"new_tag"`
+ Priority *int64 `json:"priority"`
+ Weight *uint `json:"weight"`
+ ModelMapping *string `json:"model_mapping"`
+ Models *string `json:"models"`
+ Groups *string `json:"groups"`
+}
+
+func DisableTagChannels(c *gin.Context) {
+ channelTag := ChannelTag{}
+ err := c.ShouldBindJSON(&channelTag)
+ if err != nil || channelTag.Tag == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ err = model.DisableChannelByTag(channelTag.Tag)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func EnableTagChannels(c *gin.Context) {
+ channelTag := ChannelTag{}
+ err := c.ShouldBindJSON(&channelTag)
+ if err != nil || channelTag.Tag == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ err = model.EnableChannelByTag(channelTag.Tag)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func EditTagChannels(c *gin.Context) {
+ channelTag := ChannelTag{}
+ err := c.ShouldBindJSON(&channelTag)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ if channelTag.Tag == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "tag不能为空",
+ })
+ return
+ }
+ err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
type ChannelBatch struct {
Ids []int `json:"ids"`
}
diff --git a/model/ability.go b/model/ability.go
index 115ceb19..8e084cf9 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -10,12 +10,13 @@ import (
)
type Ability struct {
- Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
- Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
- ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
- Enabled bool `json:"enabled"`
- Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
- Weight uint `json:"weight" gorm:"default:0;index"`
+ Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
+ Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
+ ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
+ Enabled bool `json:"enabled"`
+ Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
+ Weight uint `json:"weight" gorm:"default:0;index"`
+ Tag *string `json:"tag" gorm:"index"`
}
func GetGroupModels(group string) []string {
@@ -149,6 +150,7 @@ func (channel *Channel) AddAbilities() error {
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
+ Tag: channel.Tag,
}
abilities = append(abilities, ability)
}
@@ -190,6 +192,24 @@ func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
+func UpdateAbilityStatusByTag(tag string, status bool) error {
+ return DB.Model(&Ability{}).Where("tag = ?", tag).Select("enabled").Update("enabled", status).Error
+}
+
+func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uint) error {
+ ability := Ability{}
+ if newTag != nil {
+ ability.Tag = newTag
+ }
+ if priority != nil {
+ ability.Priority = priority
+ }
+ if weight != nil {
+ ability.Weight = *weight
+ }
+ return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
+}
+
func FixAbility() (int, error) {
var channelIds []int
count := 0
diff --git a/model/channel.go b/model/channel.go
index 34aae68a..3fcb7611 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -32,6 +32,7 @@ type Channel struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
+ Tag *string `json:"tag" gorm:"index"`
}
func (channel *Channel) GetModels() []string {
@@ -61,6 +62,17 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
channel.OtherInfo = string(otherInfoBytes)
}
+func (channel *Channel) GetTag() string {
+ if channel.Tag == nil {
+ return ""
+ }
+ return *channel.Tag
+}
+
+func (channel *Channel) SetTag(tag string) {
+ channel.Tag = &tag
+}
+
func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
@@ -87,7 +99,13 @@ func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Chan
return channels, err
}
-func SearchChannels(keyword string, group string, model string) ([]*Channel, error) {
+func GetChannelsByTag(tag string) ([]*Channel, error) {
+ var channels []*Channel
+ err := DB.Where("tag = ?", tag).Find(&channels).Error
+ return channels, err
+}
+
+func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
var channels []*Channel
keyCol := "`key`"
groupCol := "`group`"
@@ -100,6 +118,11 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
modelsCol = `"models"`
}
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+
// 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
@@ -122,7 +145,7 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
}
// 执行查询
- err := baseQuery.Where(whereClause, args...).Order("priority desc").Find(&channels).Error
+ err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error
if err != nil {
return nil, err
}
@@ -288,6 +311,74 @@ func UpdateChannelStatusById(id int, status int, reason string) {
}
+func EnableChannelByTag(tag string) error {
+ err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error
+ if err != nil {
+ return err
+ }
+ err = UpdateAbilityStatusByTag(tag, true)
+ return err
+}
+
+func DisableChannelByTag(tag string) error {
+ err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error
+ if err != nil {
+ return err
+ }
+ err = UpdateAbilityStatusByTag(tag, false)
+ return err
+}
+
+func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error {
+ updateData := Channel{}
+ shouldReCreateAbilities := false
+ updatedTag := tag
+ // 如果 newTag 不为空且不等于 tag,则更新 tag
+ if newTag != nil && *newTag != tag {
+ updateData.Tag = newTag
+ updatedTag = *newTag
+ }
+ if modelMapping != nil && *modelMapping != "" {
+ updateData.ModelMapping = modelMapping
+ }
+ if models != nil && *models != "" {
+ shouldReCreateAbilities = true
+ updateData.Models = *models
+ }
+ if group != nil && *group != "" {
+ shouldReCreateAbilities = true
+ updateData.Group = *group
+ }
+ if priority != nil {
+ updateData.Priority = priority
+ }
+ if weight != nil {
+ updateData.Weight = weight
+ }
+
+ err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
+ if err != nil {
+ return err
+ }
+ if shouldReCreateAbilities {
+ channels, err := GetChannelsByTag(updatedTag)
+ if err == nil {
+ for _, channel := range channels {
+ err = channel.UpdateAbilities()
+ if err != nil {
+ common.SysError("failed to update abilities: " + err.Error())
+ }
+ }
+ }
+ } else {
+ err := UpdateAbilityByTag(tag, newTag, priority, weight)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
func UpdateChannelUsedQuota(id int, quota int) {
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 75fdf4e3..bac0578c 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -98,6 +98,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
shouldSendLastResp = false
}
}
+ for _, choice := range lastStreamResponse.Choices {
+ if choice.FinishReason != nil {
+ shouldSendLastResp = true
+ }
+ }
}
if shouldSendLastResp {
service.StringData(c, lastStreamData)
diff --git a/router/api-router.go b/router/api-router.go
index 3b7eb36a..81a1341b 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -91,6 +91,9 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
+ channelRoute.POST("/tag/disabled", controller.DisableTagChannels)
+ channelRoute.POST("/tag/enabled", controller.EnableTagChannels)
+ channelRoute.PUT("/tag", controller.EditTagChannels)
channelRoute.DELETE("/:id", controller.DeleteChannel)
channelRoute.POST("/batch", controller.DeleteChannelBatch)
channelRoute.POST("/fix", controller.FixChannelsAbilities)
diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js
index d09f34fa..2ffa83fe 100644
--- a/web/src/components/ChannelsTable.js
+++ b/web/src/components/ChannelsTable.js
@@ -7,20 +7,21 @@ import {
showInfo,
showSuccess,
showWarning,
- timestamp2string,
+ timestamp2string
} from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import {
+ getQuotaPerUnit,
renderGroup,
renderNumberWithPoint,
- renderQuota,
+ renderQuota, renderQuotaWithPrompt
} from '../helpers/render';
import {
Button, Divider,
Dropdown,
- Form,
- InputNumber,
+ Form, Input,
+ InputNumber, Modal,
Popconfirm,
Space,
SplitButtonGroup,
@@ -28,11 +29,13 @@ import {
Table,
Tag,
Tooltip,
- Typography,
+ Typography
} from '@douyinfe/semi-ui';
import EditChannel from '../pages/Channel/EditChannel';
-import { IconTreeTriangleDown } from '@douyinfe/semi-icons';
+import { IconList, IconTreeTriangleDown } from '@douyinfe/semi-icons';
import { loadChannelModels } from './utils.js';
+import EditTagModal from '../pages/Channel/EditTagModal.js';
+import TextNumberInput from './custom/TextNumberInput.js';
function renderTimestamp(timestamp) {
return <>{timestamp2string(timestamp)}>;
@@ -49,12 +52,26 @@ function renderType(type) {
type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
}
return (
-