Files
new-api/controller/channel.go
t0ng7u aeb393e391 🚀 feat: Align search API with channel listing & fix sorting toggle
1. Backend
   • `controller/channel.go`
     – Added pagination (`p`, `page_size`) support to `SearchChannels`.
     – Added independent `type` filter (keeps `type_counts` unaffected).
     – Returned `total`, `type_counts` to match `/api/channel/` response.

2. Frontend
   • `ChannelsTable.js`
     – `loadChannels` / `searchChannels` now pass `p`, `page_size`, `id_sort`, `type`, `status` correctly.
     – Pagination, page-size selector and type tabs work for both normal list and search mode.
     – Switch for “ID sort” calls proper API and keeps UI state in sync.
     – Removed unnecessary `normalize` helper; `getFormValues` back to concise form.

Result
• Search mode and normal listing now share identical pagination and filtering behavior.
• Type tabs show correct counts even after searching.
• “ID Sort” toggle no longer inverses actual behaviour.
2025-06-24 05:13:47 +08:00

801 lines
18 KiB
Go

package controller
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group string `json:"group"`
IsBlocking bool `json:"is_blocking"`
} `json:"permission"`
Root string `json:"root"`
Parent string `json:"parent"`
}
type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"`
Success bool `json:"success"`
}
func parseStatusFilter(statusParam string) int {
switch strings.ToLower(statusParam) {
case "enabled", "1":
return common.ChannelStatusEnabled
case "disabled", "0":
return 0
default:
return -1
}
}
func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
if pageSize < 1 {
pageSize = common.ItemsPerPage
}
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
statusParam := c.Query("status")
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
statusFilter := parseStatusFilter(statusParam)
// type filter
typeStr := c.Query("type")
typeFilter := -1
if typeStr != "" {
if t, err := strconv.Atoi(typeStr); err == nil {
typeFilter = t
}
}
var total int64
if enableTagMode {
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
for _, tag := range tags {
if tag == nil || *tag == "" {
continue
}
tagChannels, err := model.GetChannelsByTag(*tag, idSort)
if err != nil {
continue
}
filtered := make([]*model.Channel, 0)
for _, ch := range tagChannels {
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
continue
}
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
continue
}
if typeFilter >= 0 && ch.Type != typeFilter {
continue
}
filtered = append(filtered, ch)
}
channelData = append(channelData, filtered...)
}
total, _ = model.CountAllTags()
} else {
baseQuery := model.DB.Model(&model.Channel{})
if typeFilter >= 0 {
baseQuery = baseQuery.Where("type = ?", typeFilter)
}
if statusFilter == common.ChannelStatusEnabled {
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
}
baseQuery.Count(&total)
order := "priority desc"
if idSort {
order = "id desc"
}
err := baseQuery.Order(order).Limit(pageSize).Offset((p-1)*pageSize).Omit("key").Find(&channelData).Error
if err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
}
countQuery := model.DB.Model(&model.Channel{})
if statusFilter == common.ChannelStatusEnabled {
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
}
var results []struct {
Type int64
Count int64
}
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
typeCounts := make(map[int64]int64)
for _, r := range results {
typeCounts[r.Type] = r.Count
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": channelData,
"total": total,
"page": p,
"page_size": pageSize,
"type_counts": typeCounts,
},
})
return
}
func FetchUpstreamModels(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
url := fmt.Sprintf("%s/v1/models", baseURL)
switch channel.Type {
case common.ChannelTypeGemini:
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
case common.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
}
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
var result OpenAIModelsResponse
if err = json.Unmarshal(body, &result); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
})
return
}
var ids []string
for _, model := range result.Data {
id := model.ID
if channel.Type == common.ChannelTypeGemini {
id = strings.TrimPrefix(id, "models/")
}
ids = append(ids, id)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ids,
})
}
func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
}
func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
statusParam := c.Query("status")
statusFilter := parseStatusFilter(statusParam)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
channelData := make([]*model.Channel, 0)
if enableTagMode {
tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
for _, tag := range tags {
if tag != nil && *tag != "" {
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
if err == nil {
channelData = append(channelData, tagChannel...)
}
}
}
} else {
channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channelData = channels
}
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
filtered := make([]*model.Channel, 0, len(channelData))
for _, ch := range channelData {
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
continue
}
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
continue
}
filtered = append(filtered, ch)
}
channelData = filtered
}
// calculate type counts for search results
typeCounts := make(map[int64]int64)
for _, channel := range channelData {
typeCounts[int64(channel.Type)]++
}
typeParam := c.Query("type")
typeFilter := -1
if typeParam != "" {
if tp, err := strconv.Atoi(typeParam); err == nil {
typeFilter = tp
}
}
if typeFilter >= 0 {
filtered := make([]*model.Channel, 0, len(channelData))
for _, ch := range channelData {
if ch.Type == typeFilter {
filtered = append(filtered, ch)
}
}
channelData = filtered
}
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page < 1 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
total := len(channelData)
startIdx := (page - 1) * pageSize
if startIdx > total {
startIdx = total
}
endIdx := startIdx + pageSize
if endIdx > total {
endIdx = total
}
pagedData := channelData[startIdx:endIdx]
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": pagedData,
"total": total,
"type_counts": typeCounts,
},
})
return
}
func GetChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channel,
})
return
}
func AddChannel(c *gin.Context) {
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
keys = []string{channel.Key}
}
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {
if key == "" {
continue
}
localChannel := channel
localChannel.Key = key
// Validate the length of the model name
models := strings.Split(localChannel.Models, ",")
for _, model := range models {
if len(model) > 255 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("模型名称过长: %s", model),
})
return
}
}
channels = append(channels, localChannel)
}
err = model.BatchInsertChannels(channels)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func DeleteChannel(c *gin.Context) {
id, _ := strconv.Atoi(c.Param("id"))
channel := model.Channel{Id: id}
err := channel.Delete()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func DeleteDisabledChannel(c *gin.Context) {
rows, err := model.DeleteDisabledChannel()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": rows,
})
return
}
type ChannelTag struct {
Tag string `json:"tag"`
NewTag *string `json:"new_tag"`
Priority *int64 `json:"priority"`
Weight *uint `json:"weight"`
ModelMapping *string `json:"model_mapping"`
Models *string `json:"models"`
Groups *string `json:"groups"`
}
func DisableTagChannels(c *gin.Context) {
channelTag := ChannelTag{}
err := c.ShouldBindJSON(&channelTag)
if err != nil || channelTag.Tag == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.DisableChannelByTag(channelTag.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func EnableTagChannels(c *gin.Context) {
channelTag := ChannelTag{}
err := c.ShouldBindJSON(&channelTag)
if err != nil || channelTag.Tag == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.EnableChannelByTag(channelTag.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func EditTagChannels(c *gin.Context) {
channelTag := ChannelTag{}
err := c.ShouldBindJSON(&channelTag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
if channelTag.Tag == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "tag不能为空",
})
return
}
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
type ChannelBatch struct {
Ids []int `json:"ids"`
Tag *string `json:"tag"`
}
func DeleteChannelBatch(c *gin.Context) {
channelBatch := ChannelBatch{}
err := c.ShouldBindJSON(&channelBatch)
if err != nil || len(channelBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.BatchDeleteChannels(channelBatch.Ids)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": len(channelBatch.Ids),
})
return
}
func UpdateChannel(c *gin.Context) {
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
}
err = channel.Update()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel.Key = ""
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channel,
})
return
}
func FetchModels(c *gin.Context) {
var req struct {
BaseURL string `json:"base_url"`
Type int `json:"type"`
Key string `json:"key"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Invalid request",
})
return
}
baseURL := req.BaseURL
if baseURL == "" {
baseURL = common.ChannelBaseURLs[req.Type]
}
client := &http.Client{}
url := fmt.Sprintf("%s/v1/models", baseURL)
request, err := http.NewRequest("GET", url, nil)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// remove line breaks and extra spaces.
key := strings.TrimSpace(req.Key)
// If the key contains a line break, only take the first part.
key = strings.Split(key, "\n")[0]
request.Header.Set("Authorization", "Bearer "+key)
response, err := client.Do(request)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}
//check status code
if response.StatusCode != http.StatusOK {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "Failed to fetch models",
})
return
}
defer response.Body.Close()
var result struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}
var models []string
for _, model := range result.Data {
models = append(models, model.ID)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": models,
})
}
func BatchSetChannelTag(c *gin.Context) {
channelBatch := ChannelBatch{}
err := c.ShouldBindJSON(&channelBatch)
if err != nil || len(channelBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": len(channelBatch.Ids),
})
return
}
func GetTagModels(c *gin.Context) {
tag := c.Query("tag")
if tag == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "tag不能为空",
})
return
}
channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}
var longestModels string
maxLength := 0
// Find the longest models string among all channels with the given tag
for _, channel := range channels {
if channel.Models != "" {
currentModels := strings.Split(channel.Models, ",")
if len(currentModels) > maxLength {
maxLength = len(currentModels)
longestModels = channel.Models
}
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": longestModels,
})
return
}