Files
new-api/controller/channel.go

1462 lines
37 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// OpenAIModel is one element of OpenAIModelsResponse.Data. The field names and
// JSON tags follow the snake_case layout of an OpenAI-style model listing.
// Within this file only the ID field is read back (see FetchUpstreamModels).
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
// Permission is decoded as a list of anonymous permission records.
Permission []struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group string `json:"group"`
IsBlocking bool `json:"is_blocking"`
} `json:"permission"`
Root string `json:"root"`
Parent string `json:"parent"`
}
// GoogleOpenAICompatibleModels is a slice of model descriptors using camelCase
// JSON field names (name, displayName, inputTokenLimit, ...).
//
// NOTE(review): this type is itself a slice, and GoogleOpenAICompatibleResponse
// declares Models as []GoogleOpenAICompatibleModels, i.e. a slice of slices.
// Confirm that this matches the upstream payload; a flat array of model objects
// under "models" would not decode into that shape.
type GoogleOpenAICompatibleModels []struct {
Name string `json:"name"`
Version string `json:"version"`
DisplayName string `json:"displayName"`
Description string `json:"description,omitempty"`
InputTokenLimit int `json:"inputTokenLimit"`
OutputTokenLimit int `json:"outputTokenLimit"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
MaxTemperature int `json:"maxTemperature,omitempty"`
}
// OpenAIModelsResponse is the envelope FetchUpstreamModels decodes an upstream
// model listing into: the models live under "data", next to a "success" flag.
type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"`
Success bool `json:"success"`
}
// GoogleOpenAICompatibleResponse is the envelope FetchUpstreamModels tries first
// for Gemini channels: models under "models" plus a "nextPageToken" string.
// The page token is decoded but not used anywhere in this file.
type GoogleOpenAICompatibleResponse struct {
Models []GoogleOpenAICompatibleModels `json:"models"`
NextPageToken string `json:"nextPageToken"`
}
// parseStatusFilter converts the "status" query parameter into a filter value.
// The comparison is case-insensitive:
//
//	"enabled" or "1"  -> common.ChannelStatusEnabled
//	"disabled" or "0" -> 0
//	anything else     -> -1 (callers treat this as "no status filter")
func parseStatusFilter(statusParam string) int {
	normalized := strings.ToLower(statusParam)
	if normalized == "enabled" || normalized == "1" {
		return common.ChannelStatusEnabled
	}
	if normalized == "disabled" || normalized == "0" {
		return 0
	}
	return -1
}
// clearChannelInfo blanks the per-key disabled-reason and disabled-time maps of
// a multi-key channel in place. Channels that are not multi-key are left as-is.
// The handlers in this file call it on channels right before serializing them.
func clearChannelInfo(channel *model.Channel) {
	if !channel.ChannelInfo.IsMultiKey {
		return
	}
	channel.ChannelInfo.MultiKeyDisabledTime = nil
	channel.ChannelInfo.MultiKeyDisabledReason = nil
}
// GetAllChannels returns one page of channels together with per-type counts.
//
// Query parameters read here (pagination itself comes from common.GetPageQuery):
//
//	id_sort  - bool; order by id instead of priority
//	tag_mode - bool; page over tags and return every channel of each tag
//	status   - see parseStatusFilter
//	type     - integer channel type; ignored when absent or not a number
//
// In tag mode "total" is taken from model.CountAllTags; otherwise it is the
// number of channels matching the type and status filters. "type_counts" is
// computed with the status filter only, never the type filter.
func GetAllChannels(c *gin.Context) {
	pageInfo := common.GetPageQuery(c)
	channelData := make([]*model.Channel, 0)
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
	// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
	statusFilter := parseStatusFilter(c.Query("status"))
	// typeFilter stays -1 (no filtering) unless a valid integer was supplied.
	typeFilter := -1
	if typeStr := c.Query("type"); typeStr != "" {
		if t, err := strconv.Atoi(typeStr); err == nil {
			typeFilter = t
		}
	}

	var total int64
	if enableTagMode {
		tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
		if err != nil {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
			return
		}
		for _, tag := range tags {
			if tag == nil || *tag == "" {
				continue
			}
			tagChannels, err := model.GetChannelsByTag(*tag, idSort)
			if err != nil {
				// Best effort: a tag whose channels cannot be loaded is skipped.
				continue
			}
			for _, ch := range tagChannels {
				if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
					continue
				}
				if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
					continue
				}
				if typeFilter >= 0 && ch.Type != typeFilter {
					continue
				}
				channelData = append(channelData, ch)
			}
		}
		total, _ = model.CountAllTags()
	} else {
		baseQuery := model.DB.Model(&model.Channel{})
		if typeFilter >= 0 {
			baseQuery = baseQuery.Where("type = ?", typeFilter)
		}
		if statusFilter == common.ChannelStatusEnabled {
			baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
		} else if statusFilter == 0 {
			baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
		}
		// Report a failing count instead of silently returning total == 0.
		if err := baseQuery.Count(&total).Error; err != nil {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
			return
		}
		order := "priority desc"
		if idSort {
			order = "id desc"
		}
		// The key column is omitted so secrets never reach the list response.
		err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
		if err != nil {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
			return
		}
	}

	for _, datum := range channelData {
		clearChannelInfo(datum)
	}

	// Per-type counts honour the status filter but not the type filter.
	countQuery := model.DB.Model(&model.Channel{})
	if statusFilter == common.ChannelStatusEnabled {
		countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
	} else if statusFilter == 0 {
		countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
	}
	var results []struct {
		Type  int64
		Count int64
	}
	// Best effort: on failure the response simply carries an empty type_counts map.
	_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
	typeCounts := make(map[int64]int64)
	for _, r := range results {
		typeCounts[r.Type] = r.Count
	}

	common.ApiSuccess(c, gin.H{
		"items":       channelData,
		"total":       total,
		"page":        pageInfo.GetPage(),
		"page_size":   pageInfo.GetPageSize(),
		"type_counts": typeCounts,
	})
}
// FetchUpstreamModels asks the upstream provider of the channel identified by
// the :id path parameter for its model list and returns the model IDs.
//
// The endpoint depends on the channel type (Gemini, Ali, or the default
// /v1/models). Gemini requests carry the key in the query string and no auth
// header; every other type sends GetAuthHeader(channel.Key).
//
// For Gemini the body is first decoded as GoogleOpenAICompatibleResponse. If
// that yields no models (including the case where the body is an OpenAI-style
// {"data": [...]} document, which decodes without error but without content),
// the body is decoded as OpenAIModelsResponse instead.
func FetchUpstreamModels(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, err)
		return
	}
	channel, err := model.GetChannelById(id, true)
	if err != nil {
		common.ApiError(c, err)
		return
	}

	// A base URL configured on the channel overrides the per-type default.
	baseURL := constant.ChannelBaseURLs[channel.Type]
	if channel.GetBaseURL() != "" {
		baseURL = channel.GetBaseURL()
	}
	isGemini := channel.Type == constant.ChannelTypeGemini

	var url string
	switch channel.Type {
	case constant.ChannelTypeGemini:
		// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
		url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key)
	case constant.ChannelTypeAli:
		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
	default:
		url = fmt.Sprintf("%s/v1/models", baseURL)
	}

	// Fetch the response body; whether an auth header is sent depends on the channel type.
	var body []byte
	if isGemini {
		body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader
	} else {
		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	}
	if err != nil {
		common.ApiError(c, err)
		return
	}

	var result OpenAIModelsResponse
	parseSuccess := false
	// Adapt provider-specific formats.
	if isGemini {
		var googleResult GoogleOpenAICompatibleResponse
		if err = json.Unmarshal(body, &googleResult); err == nil {
			// Convert the Google format to the OpenAI format.
			for _, group := range googleResult.Models {
				for _, gModel := range group {
					result.Data = append(result.Data, OpenAIModel{ID: gModel.Name})
				}
			}
			// A document without a usable "models" array decodes without error,
			// so only treat the Google parse as successful when it produced models.
			parseSuccess = len(result.Data) > 0
		}
	}
	// Fall back to the OpenAI format when the provider-specific parse produced nothing.
	if !parseSuccess {
		if err = json.Unmarshal(body, &result); err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
			})
			return
		}
	}

	var ids []string
	for _, upstreamModel := range result.Data {
		modelID := upstreamModel.ID
		if isGemini {
			// Gemini IDs come back as "models/<name>"; expose the bare name.
			modelID = strings.TrimPrefix(modelID, "models/")
		}
		ids = append(ids, modelID)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    ids,
	})
}
// FixChannelsAbilities runs model.FixAbility and reports the two values it
// returns under "data.success" and "data.fails". A failure of FixAbility is
// forwarded through common.ApiError.
func FixChannelsAbilities(c *gin.Context) {
	success, fails, err := model.FixAbility()
	if err != nil {
		common.ApiError(c, err)
		return
	}
	stats := gin.H{
		"success": success,
		"fails":   fails,
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    stats,
	})
}
// SearchChannels searches channels (or, in tag mode, the channels of matching
// tags) by keyword, group and model, then filters and paginates in memory.
//
// Query parameters: keyword, group, model, status (see parseStatusFilter),
// id_sort, tag_mode, type, p (page, default 1) and page_size (default 20).
// "type_counts" is computed after the status filter but before the type
// filter; "total" is the count after both filters.
func SearchChannels(c *gin.Context) {
	keyword := c.Query("keyword")
	group := c.Query("group")
	modelKeyword := c.Query("model")
	statusFilter := parseStatusFilter(c.Query("status"))
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))

	channelData := make([]*model.Channel, 0)
	if !enableTagMode {
		channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		channelData = channels
	} else {
		tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		for _, tag := range tags {
			if tag == nil || *tag == "" {
				continue
			}
			// Tags whose channels cannot be loaded are skipped silently.
			tagChannel, err := model.GetChannelsByTag(*tag, idSort)
			if err != nil {
				continue
			}
			channelData = append(channelData, tagChannel...)
		}
	}

	// Status filter: keep only enabled, or only non-enabled, channels.
	if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
		kept := make([]*model.Channel, 0, len(channelData))
		for _, ch := range channelData {
			isEnabled := ch.Status == common.ChannelStatusEnabled
			if (statusFilter == common.ChannelStatusEnabled && !isEnabled) || (statusFilter == 0 && isEnabled) {
				continue
			}
			kept = append(kept, ch)
		}
		channelData = kept
	}

	// Type counts describe the search result before the type filter is applied.
	typeCounts := make(map[int64]int64)
	for _, ch := range channelData {
		typeCounts[int64(ch.Type)]++
	}

	// Type filter: -1 means "not requested or not a number".
	typeFilter := -1
	if typeParam := c.Query("type"); typeParam != "" {
		if tp, err := strconv.Atoi(typeParam); err == nil {
			typeFilter = tp
		}
	}
	if typeFilter >= 0 {
		kept := make([]*model.Channel, 0, len(channelData))
		for _, ch := range channelData {
			if ch.Type != typeFilter {
				continue
			}
			kept = append(kept, ch)
		}
		channelData = kept
	}

	// In-memory pagination with defaults for missing or invalid values.
	page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
	if page < 1 {
		page = 1
	}
	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	if pageSize <= 0 {
		pageSize = 20
	}
	total := len(channelData)
	startIdx := (page - 1) * pageSize
	if startIdx > total {
		startIdx = total
	}
	endIdx := startIdx + pageSize
	if endIdx > total {
		endIdx = total
	}
	pagedData := channelData[startIdx:endIdx]
	for _, datum := range pagedData {
		clearChannelInfo(datum)
	}

	payload := gin.H{
		"items":       pagedData,
		"total":       total,
		"type_counts": typeCounts,
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    payload,
	})
}
// GetChannel returns the channel identified by the :id path parameter. The
// channel is loaded with GetChannelById(id, false) and passed through
// clearChannelInfo before being serialized.
func GetChannel(c *gin.Context) {
	id, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		common.ApiError(c, parseErr)
		return
	}
	channel, lookupErr := model.GetChannelById(id, false)
	if lookupErr != nil {
		common.ApiError(c, lookupErr)
		return
	}
	if channel != nil {
		clearChannelInfo(channel)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channel,
	})
}
// validateChannel is the shared validation routine for adding and updating
// channels. It returns nil when the channel is acceptable and a descriptive
// error otherwise.
//
// Checks, in order:
//   - channel must not be nil (checked first so nothing below dereferences nil);
//   - channel.ValidateSettings() must succeed;
//   - when isAdd is true: the key must be non-empty and every model name must
//     be at most 255 bytes long;
//   - Vertex AI channels: Other must be a JSON object containing a "default"
//     region entry.
func validateChannel(channel *model.Channel, isAdd bool) error {
	if channel == nil {
		return fmt.Errorf("channel cannot be empty")
	}
	// Validate the channel settings.
	if err := channel.ValidateSettings(); err != nil {
		return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
	}
	// Extra checks that only apply when a channel is being created.
	if isAdd {
		if channel.Key == "" {
			return fmt.Errorf("channel cannot be empty")
		}
		// Reject model names longer than 255 bytes.
		for _, m := range channel.GetModels() {
			if len(m) > 255 {
				return fmt.Errorf("模型名称过长: %s", m)
			}
		}
	}
	// Vertex AI specific validation of the deployment-region map.
	if channel.Type == constant.ChannelTypeVertexAi {
		if channel.Other == "" {
			return fmt.Errorf("部署地区不能为空")
		}
		regionMap, err := common.StrToMap(channel.Other)
		if err != nil {
			return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
		}
		if regionMap["default"] == nil {
			return fmt.Errorf("部署地区必须包含default字段")
		}
	}
	return nil
}
// AddChannelRequest is the JSON body accepted by AddChannel.
type AddChannelRequest struct {
// Mode selects how Channel.Key is interpreted; AddChannel accepts
// "single", "batch" and "multi_to_single" and rejects anything else.
Mode string `json:"mode"`
// MultiKeyMode is copied into the channel's ChannelInfo in "multi_to_single" mode.
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
// Channel is the channel to create.
Channel *model.Channel `json:"channel"`
}
// getVertexArrayKeys parses a JSON array of Vertex AI keys into one string per
// key. String elements are whitespace-trimmed; any other element is re-encoded
// as JSON text. Elements that end up empty are dropped.
//
// An empty input returns (nil, nil). An input that is not a JSON array, or
// that contains no usable key, returns an error.
func getVertexArrayKeys(keys string) ([]string, error) {
	if keys == "" {
		return nil, nil
	}
	var rawItems []interface{}
	if err := common.Unmarshal([]byte(keys), &rawItems); err != nil {
		return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
	}
	cleanKeys := make([]string, 0, len(rawItems))
	for _, item := range rawItems {
		var keyStr string
		if text, isString := item.(string); isString {
			keyStr = strings.TrimSpace(text)
		} else {
			encoded, err := json.Marshal(item)
			if err != nil {
				return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
			}
			keyStr = string(encoded)
		}
		if keyStr == "" {
			continue
		}
		cleanKeys = append(cleanKeys, keyStr)
	}
	if len(cleanKeys) == 0 {
		return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
	}
	return cleanKeys, nil
}
// AddChannel creates one or more channels from an AddChannelRequest.
//
// Modes:
//
//	"single"          - one channel using Channel.Key as given.
//	"batch"           - one channel per key; Vertex AI keys come from a JSON
//	                    array, other types from newline-separated lines.
//	"multi_to_single" - one multi-key channel whose keys are normalized and
//	                    joined with "\n"; MultiKeySize is set to the key count.
//
// Lines are trimmed before the emptiness check, so blank or whitespace-only
// lines never become keys and are not counted in MultiKeySize.
func AddChannel(c *gin.Context) {
	addChannelRequest := AddChannelRequest{}
	err := c.ShouldBindJSON(&addChannelRequest)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	// Use the shared validation routine.
	if err := validateChannel(addChannelRequest.Channel, true); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	channel := addChannelRequest.Channel
	channel.CreatedTime = common.GetTimestamp()

	keys := make([]string, 0)
	switch addChannelRequest.Mode {
	case "multi_to_single":
		channel.ChannelInfo.IsMultiKey = true
		channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
		if channel.Type == constant.ChannelTypeVertexAi {
			array, err := getVertexArrayKeys(channel.Key)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
			channel.ChannelInfo.MultiKeySize = len(array)
			channel.Key = strings.Join(array, "\n")
		} else {
			cleanKeys := make([]string, 0)
			for _, key := range strings.Split(channel.Key, "\n") {
				// Trim first so whitespace-only lines are dropped rather than stored as empty keys.
				key = strings.TrimSpace(key)
				if key == "" {
					continue
				}
				cleanKeys = append(cleanKeys, key)
			}
			channel.ChannelInfo.MultiKeySize = len(cleanKeys)
			channel.Key = strings.Join(cleanKeys, "\n")
		}
		keys = []string{channel.Key}
	case "batch":
		if channel.Type == constant.ChannelTypeVertexAi {
			// multi json
			keys, err = getVertexArrayKeys(channel.Key)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
		} else {
			for _, key := range strings.Split(channel.Key, "\n") {
				// Strip stray spaces and carriage returns; skip blank lines.
				if key = strings.TrimSpace(key); key != "" {
					keys = append(keys, key)
				}
			}
		}
	case "single":
		keys = []string{channel.Key}
	default:
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "不支持的添加模式",
		})
		return
	}

	channels := make([]model.Channel, 0, len(keys))
	for _, key := range keys {
		if key == "" {
			continue
		}
		// Copy the template by value so each inserted channel owns its key.
		localChannel := *channel
		localChannel.Key = key
		channels = append(channels, localChannel)
	}
	err = model.BatchInsertChannels(channels)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// DeleteChannel deletes the channel identified by the :id path parameter and
// then refreshes the channel cache. A non-numeric id is rejected through
// common.ApiError, as in the other :id handlers of this file, instead of
// falling through to a delete with id 0.
func DeleteChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, err)
		return
	}
	channel := model.Channel{Id: id}
	if err := channel.Delete(); err != nil {
		common.ApiError(c, err)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// DeleteDisabledChannel calls model.DeleteDisabledChannel, refreshes the
// channel cache, and returns the value reported by the model layer as "data".
func DeleteDisabledChannel(c *gin.Context) {
	rows, err := model.DeleteDisabledChannel()
	if err != nil {
		common.ApiError(c, err)
		return
	}
	model.InitChannelCache()
	response := gin.H{
		"success": true,
		"message": "",
		"data":    rows,
	}
	c.JSON(http.StatusOK, response)
}
// ChannelTag is the JSON body of the tag-level handlers (DisableTagChannels,
// EnableTagChannels, EditTagChannels). Tag names the tag to act on and must be
// non-empty. The pointer fields are handed to model.EditChannelByTag unchanged.
// NOTE(review): a nil pointer presumably means "leave this attribute
// unchanged" — confirm in model.EditChannelByTag.
type ChannelTag struct {
Tag string `json:"tag"`
NewTag *string `json:"new_tag"`
Priority *int64 `json:"priority"`
Weight *uint `json:"weight"`
ModelMapping *string `json:"model_mapping"`
Models *string `json:"models"`
Groups *string `json:"groups"`
}
// DisableTagChannels passes the tag named in the request body to
// model.DisableChannelByTag and refreshes the channel cache. A body that
// cannot be bound, or an empty tag, yields a "参数错误" failure response.
func DisableTagChannels(c *gin.Context) {
	var channelTag ChannelTag
	if err := c.ShouldBindJSON(&channelTag); err != nil || channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	if err := model.DisableChannelByTag(channelTag.Tag); err != nil {
		common.ApiError(c, err)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// EnableTagChannels passes the tag named in the request body to
// model.EnableChannelByTag and refreshes the channel cache. A body that cannot
// be bound, or an empty tag, yields a "参数错误" failure response.
func EnableTagChannels(c *gin.Context) {
	var channelTag ChannelTag
	bindErr := c.ShouldBindJSON(&channelTag)
	if bindErr != nil || channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	if enableErr := model.EnableChannelByTag(channelTag.Tag); enableErr != nil {
		common.ApiError(c, enableErr)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// EditTagChannels forwards the fields of a ChannelTag request body to
// model.EditChannelByTag and refreshes the channel cache afterwards. An
// unbindable body and an empty tag are rejected with distinct messages.
func EditTagChannels(c *gin.Context) {
	var channelTag ChannelTag
	if err := c.ShouldBindJSON(&channelTag); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	if channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "tag不能为空",
		})
		return
	}
	editErr := model.EditChannelByTag(
		channelTag.Tag,
		channelTag.NewTag,
		channelTag.ModelMapping,
		channelTag.Models,
		channelTag.Groups,
		channelTag.Priority,
		channelTag.Weight,
	)
	if editErr != nil {
		common.ApiError(c, editErr)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// ChannelBatch is the JSON body of the batch handlers (DeleteChannelBatch,
// BatchSetChannelTag). Both handlers reject an empty Ids list. Tag is only
// read by BatchSetChannelTag, which passes it to model.BatchSetChannelTag.
type ChannelBatch struct {
Ids []int `json:"ids"`
Tag *string `json:"tag"`
}
// DeleteChannelBatch deletes the channels whose ids are listed in the request
// body, refreshes the channel cache, and reports the number of ids submitted
// as "data". A body that cannot be bound or an empty id list is rejected.
func DeleteChannelBatch(c *gin.Context) {
	var channelBatch ChannelBatch
	if err := c.ShouldBindJSON(&channelBatch); err != nil || len(channelBatch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	if err := model.BatchDeleteChannels(channelBatch.Ids); err != nil {
		common.ApiError(c, err)
		return
	}
	model.InitChannelCache()
	submitted := len(channelBatch.Ids)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    submitted,
	})
}
// PatchChannel is the JSON body accepted by UpdateChannel: a full model.Channel
// plus two optional multi-key controls.
type PatchChannel struct {
model.Channel
// MultiKeyMode, when non-nil and non-empty, overrides the stored ChannelInfo.MultiKeyMode.
MultiKeyMode *string `json:"multi_key_mode"`
KeyMode *string `json:"key_mode"` // for multi-key channels: "append" adds the submitted keys to the stored ones, "replace" overwrites them
}
// UpdateChannel applies an edit described by a PatchChannel request body.
//
// The stored ChannelInfo always replaces whatever the client sent, so fields
// such as IsMultiKey and MultiKeySize survive an edit; only MultiKeyMode can
// be overridden, through the multi_key_mode field.
//
// For multi-key channels, key_mode "append" merges the submitted keys into the
// stored ones (joined with "\n"); "replace" or any other value stores the
// submitted key as-is. Stored keys are read as a JSON array when they parse as
// one and as newline-separated text otherwise, so a stored key that merely
// starts with "[" is never discarded.
//
// The response echoes the updated channel with its key blanked.
func UpdateChannel(c *gin.Context) {
	channel := PatchChannel{}
	err := c.ShouldBindJSON(&channel)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	// Use the shared validation routine.
	if err := validateChannel(&channel.Channel, false); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
	originChannel, err := model.GetChannelById(channel.Id, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
	channel.ChannelInfo = originChannel.ChannelInfo
	// If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
	if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
		channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
	}

	// Append mode for multi-key channels: merge the submitted keys into the
	// stored ones. "replace" needs no handling because the submitted key is
	// stored as-is.
	if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey && *channel.KeyMode == "append" && originChannel.Key != "" {
		// Parse the stored keys.
		var existingKeys []string
		parsedAsArray := false
		trimmedOrigin := strings.TrimSpace(originChannel.Key)
		if strings.HasPrefix(trimmedOrigin, "[") {
			// JSON array format.
			var arr []json.RawMessage
			if err := json.Unmarshal([]byte(trimmedOrigin), &arr); err == nil {
				existingKeys = make([]string, len(arr))
				for i, v := range arr {
					existingKeys[i] = string(v)
				}
				parsedAsArray = true
			}
		}
		if !parsedAsArray {
			// Newline-separated format; also the fallback when the stored text
			// is not a valid JSON array, so the stored keys are kept.
			existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
		}

		var newKeys []string
		if channel.Type == constant.ChannelTypeVertexAi {
			// Vertex AI: the submitted key is either a JSON array of keys or a single JSON key.
			if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
				array, err := getVertexArrayKeys(channel.Key)
				if err != nil {
					c.JSON(http.StatusOK, gin.H{
						"success": false,
						"message": "追加密钥解析失败: " + err.Error(),
					})
					return
				}
				newKeys = array
			} else if strings.TrimSpace(channel.Key) != "" {
				// Single JSON key; a blank submission adds nothing.
				newKeys = []string{channel.Key}
			}
		} else {
			// Other channel types: one key per non-blank line.
			for _, key := range strings.Split(channel.Key, "\n") {
				if key = strings.TrimSpace(key); key != "" {
					newKeys = append(newKeys, key)
				}
			}
		}
		// Merge stored and submitted keys.
		channel.Key = strings.Join(append(existingKeys, newKeys...), "\n")
	}

	err = channel.Update()
	if err != nil {
		common.ApiError(c, err)
		return
	}
	model.InitChannelCache()
	// Never echo the key back to the client.
	channel.Key = ""
	clearChannelInfo(&channel.Channel)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channel,
	})
}
// FetchModels queries "<base_url>/v1/models" with the supplied key and returns
// the model IDs found under "data" in the upstream response.
//
// Request body: base_url (optional; falls back to the default base URL for
// type), type, and key. Only the first line of key is used, sent as a Bearer
// token.
//
// The upstream response body is closed on every path, including the
// non-200 path.
func FetchModels(c *gin.Context) {
	var req struct {
		BaseURL string `json:"base_url"`
		Type    int    `json:"type"`
		Key     string `json:"key"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "Invalid request",
		})
		return
	}
	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = constant.ChannelBaseURLs[req.Type]
	}
	client := &http.Client{}
	url := fmt.Sprintf("%s/v1/models", baseURL)
	request, err := http.NewRequest("GET", url, nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Use only the first line of the key, and trim it again so a trailing
	// carriage return from CRLF input does not end up in the header.
	key := strings.TrimSpace(req.Key)
	key = strings.TrimSpace(strings.Split(key, "\n")[0])
	request.Header.Set("Authorization", "Bearer "+key)

	response, err := client.Do(request)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Register the close before any early return so the body is never leaked.
	defer response.Body.Close()

	// Check the status code.
	if response.StatusCode != http.StatusOK {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": "Failed to fetch models",
		})
		return
	}
	var result struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	var models []string
	for _, entry := range result.Data {
		models = append(models, entry.ID)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    models,
	})
}
// BatchSetChannelTag hands the ids and tag from the request body to
// model.BatchSetChannelTag, refreshes the channel cache, and reports the
// number of ids submitted as "data". A body that cannot be bound or an empty
// id list is rejected.
func BatchSetChannelTag(c *gin.Context) {
	var channelBatch ChannelBatch
	if err := c.ShouldBindJSON(&channelBatch); err != nil || len(channelBatch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	if err := model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag); err != nil {
		common.ApiError(c, err)
		return
	}
	model.InitChannelCache()
	submitted := len(channelBatch.Ids)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    submitted,
	})
}
// GetTagModels returns, for the tag given in the "tag" query parameter, the
// Models string of the channel that lists the most comma-separated models.
// When several channels tie, the first one encountered wins; when no channel
// has models, the result is the empty string.
func GetTagModels(c *gin.Context) {
	tag := c.Query("tag")
	if tag == "" {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "tag不能为空",
		})
		return
	}
	// idSort is passed as false; ordering does not matter for this lookup.
	channels, err := model.GetChannelsByTag(tag, false)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	longestModels := ""
	maxLength := 0
	for _, ch := range channels {
		if ch.Models == "" {
			continue
		}
		// Strictly greater keeps the first channel on ties.
		if count := len(strings.Split(ch.Models, ",")); count > maxLength {
			maxLength = count
			longestModels = ch.Models
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    longestModels,
	})
}
// CopyChannel handles cloning an existing channel with its key.
// POST /api/channel/copy/:id
// Optional query params:
//
//	suffix        - string appended to the original name (default "_复制")
//	reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
//
// The response carries the id of the inserted channel. It is read from the
// slice handed to model.BatchInsertChannels, not from the local clone, because
// the slice holds its own copy of the struct and the local value can never be
// updated by the insert.
// NOTE(review): this assumes BatchInsertChannels populates the ids of the
// slice elements in place — confirm in the model package.
func CopyChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
		return
	}
	suffix := c.DefaultQuery("suffix", "_复制")
	// An unparsable reset_balance value keeps the default (true).
	resetBalance := true
	if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
		if v, err := strconv.ParseBool(rbStr); err == nil {
			resetBalance = v
		}
	}
	// fetch original channel with key
	origin, err := model.GetChannelById(id, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}
	// clone channel
	clone := *origin // shallow copy is sufficient as we will overwrite primitives
	clone.Id = 0     // let DB auto-generate
	clone.CreatedTime = common.GetTimestamp()
	clone.Name = origin.Name + suffix
	clone.TestTime = 0
	clone.ResponseTime = 0
	if resetBalance {
		clone.Balance = 0
		clone.UsedQuota = 0
	}
	// insert; keep a handle on the slice so the stored element can be read back
	inserted := []model.Channel{clone}
	if err := model.BatchInsertChannels(inserted); err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}
	model.InitChannelCache()
	// success
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": inserted[0].Id}})
}
// MultiKeyManageRequest represents the request for multi-key management operations.
// ManageMultiKeys dispatches on Action and accepts "get_key_status",
// "disable_key", "enable_key", "enable_all_keys", "disable_all_keys" and
// "delete_disabled_keys".
type MultiKeyManageRequest struct {
ChannelId int `json:"channel_id"`
Action string `json:"action"` // "disable_key", "enable_key", "enable_all_keys", "disable_all_keys", "delete_disabled_keys", "get_key_status"
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
Page int `json:"page,omitempty"` // for get_key_status pagination
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
}
// MultiKeyStatusResponse represents the response for key status query.
// Keys, Total, Page, PageSize and TotalPages describe the page of the
// (optionally status-filtered) key list; the three counters below always
// cover every key of the channel, regardless of the filter.
type MultiKeyStatusResponse struct {
Keys []KeyStatus `json:"keys"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
// Statistics
EnabledCount int `json:"enabled_count"`
ManualDisabledCount int `json:"manual_disabled_count"`
AutoDisabledCount int `json:"auto_disabled_count"`
}
// KeyStatus describes one key of a multi-key channel as reported by the
// "get_key_status" action of ManageMultiKeys.
type KeyStatus struct {
Index int `json:"index"`
Status int `json:"status"` // 1: enabled, 2: manually disabled, 3: automatically disabled
DisabledTime int64 `json:"disabled_time,omitempty"`
Reason string `json:"reason,omitempty"`
KeyPreview string `json:"key_preview"` // first 10 bytes of the key followed by "..." (the whole key when it is 10 bytes or shorter)
}
// ManageMultiKeys handles multi-key management operations.
//
// The request body is a MultiKeyManageRequest. The channel is loaded with its
// key, rejected unless it is a multi-key channel, and the per-channel polling
// lock is held for the rest of the handler. The action is then dispatched:
//
//	get_key_status       - paginated, optionally status-filtered key listing
//	                       plus overall enabled / disabled counters
//	disable_key          - mark key_index as manually disabled (status 2)
//	enable_key           - remove every status record for key_index
//	enable_all_keys      - reset all status records
//	disable_all_keys     - mark every currently enabled key as status 2
//	delete_disabled_keys - remove keys with status 3 and re-index the rest
//
// Every mutating action persists with channel.Update() and then refreshes the
// channel cache. Failures are reported as {"success": false, ...} with HTTP 200.
//
// NOTE(review): the channel is read before the lock is taken, so the handler
// may be working on a snapshot that predates a concurrent update — confirm
// whether the load should happen under the lock.
func ManageMultiKeys(c *gin.Context) {
request := MultiKeyManageRequest{}
err := c.ShouldBindJSON(&request)
if err != nil {
common.ApiError(c, err)
return
}
channel, err := model.GetChannelById(request.ChannelId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "渠道不存在",
})
return
}
if !channel.ChannelInfo.IsMultiKey {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该渠道不是多密钥模式",
})
return
}
// Hold the per-channel polling lock until the handler returns.
lock := model.GetChannelPollingLock(channel.Id)
lock.Lock()
defer lock.Unlock()
switch request.Action {
case "get_key_status":
keys := channel.GetKeys()
// Default pagination parameters
page := request.Page
pageSize := request.PageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 50 // Default page size
}
// Statistics for all keys (unchanged by filtering)
var enabledCount, manualDisabledCount, autoDisabledCount int
// Build all key status data first
var allKeyStatusList []KeyStatus
for i, key := range keys {
status := 1 // default enabled
var disabledTime int64
var reason string
if channel.ChannelInfo.MultiKeyStatusList != nil {
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
}
// Count for statistics (all keys)
switch status {
case 1:
enabledCount++
case 2:
manualDisabledCount++
case 3:
autoDisabledCount++
}
// Disabled time and reason are only looked up for keys that are not enabled;
// a missing map entry yields the zero value.
if status != 1 {
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
}
}
// Create key preview (first 10 chars)
keyPreview := key
if len(key) > 10 {
keyPreview = key[:10] + "..."
}
allKeyStatusList = append(allKeyStatusList, KeyStatus{
Index: i,
Status: status,
DisabledTime: disabledTime,
Reason: reason,
KeyPreview: keyPreview,
})
}
// Apply status filter if specified
var filteredKeyStatusList []KeyStatus
if request.Status != nil {
for _, keyStatus := range allKeyStatusList {
if keyStatus.Status == *request.Status {
filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
}
}
} else {
filteredKeyStatusList = allKeyStatusList
}
// Calculate pagination based on filtered results
filteredTotal := len(filteredKeyStatusList)
totalPages := (filteredTotal + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
}
// A page beyond the end is clamped to the last page.
if page > totalPages {
page = totalPages
}
// Calculate range for current page
start := (page - 1) * pageSize
end := start + pageSize
if end > filteredTotal {
end = filteredTotal
}
// Get the page data
var pageKeyStatusList []KeyStatus
if start < filteredTotal {
pageKeyStatusList = filteredKeyStatusList[start:end]
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": MultiKeyStatusResponse{
Keys: pageKeyStatusList,
Total: filteredTotal, // Total of filtered results
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
EnabledCount: enabledCount, // Overall statistics
ManualDisabledCount: manualDisabledCount, // Overall statistics
AutoDisabledCount: autoDisabledCount, // Overall statistics
},
})
return
case "disable_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要禁用的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
// The index is validated against the recorded MultiKeySize.
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
// NOTE(review): only the status is recorded here; no disabled time or reason
// is written for a manual disable — confirm this is intended.
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已禁用",
})
return
case "enable_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要启用的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
// Remove this key's records from the status maps so it falls back to the default enabled state.
if channel.ChannelInfo.MultiKeyStatusList != nil {
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
}
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
}
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已启用",
})
return
case "enable_all_keys":
// Clear every disabled record so all keys fall back to the default enabled state.
// The reported count is the number of entries the status map held beforehand.
var enabledCount int
if channel.ChannelInfo.MultiKeyStatusList != nil {
enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
}
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
})
return
case "disable_all_keys":
// Disable every key that is currently enabled.
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
var disabledCount int
for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
status := 1 // default enabled
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
// Only keys that are currently enabled are switched to disabled.
if status == 1 {
channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
disabledCount++
}
}
if disabledCount == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "没有可禁用的密钥",
})
return
}
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
})
return
case "delete_disabled_keys":
keys := channel.GetKeys()
var remainingKeys []string
var deletedCount int
// Fresh maps keyed by the post-deletion indices.
var newStatusList = make(map[int]int)
var newDisabledTime = make(map[int]int64)
var newDisabledReason = make(map[int]string)
newIndex := 0
for i, key := range keys {
status := 1 // default enabled
if channel.ChannelInfo.MultiKeyStatusList != nil {
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
}
// Only automatically disabled keys (status == 3) are deleted; enabled (status == 1)
// and manually disabled (status == 2) keys are kept.
if status == 3 {
deletedCount++
} else {
remainingKeys = append(remainingKeys, key)
// Carry over the status data of kept non-enabled keys under their new index.
if status != 1 {
newStatusList[newIndex] = status
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
newDisabledTime[newIndex] = t
}
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
newDisabledReason[newIndex] = r
}
}
}
newIndex++
}
}
if deletedCount == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "没有需要删除的自动禁用密钥",
})
return
}
// Update channel with remaining keys
channel.Key = strings.Join(remainingKeys, "\n")
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
channel.ChannelInfo.MultiKeyStatusList = newStatusList
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
"data": deletedCount,
})
return
default:
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不支持的操作",
})
return
}
}