new-api/controller/channel.go
t0ng7u f1856fe4d2 🔧 **fix(model, controller): robust serialization for ChannelInfo & correct Vertex-AI key storage**
Summary
1. **model/channel.go**
   • Replaced the pointer-only `Value()` with a value-receiver implementation
   • GORM can now marshal both `ChannelInfo` and `*ChannelInfo`, eliminating the
     `unsupported type model.ChannelInfo` runtime error (see the sketch below).

2. **controller/channel.go**
   • Refactored `getVertexArrayKeys()` – every element is now
     - passed through `json.Marshal` when not already a string
     - trimmed & validated before insertion
   • Guarantees each service-account key is persisted as a **pure JSON string**
     instead of the previous `map[...]` dump.

Result
• Channel creation / update succeeds without SQL driver errors.
• Vertex-AI batch & multi-key uploads are stored in canonical JSON, ready for
  downstream SDKs to consume.
2025-07-07 01:31:41 +08:00
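
For context, here is a minimal sketch of the value-receiver `Value()` described in point 1. `model/channel.go` itself is not shown on this page, so the struct fields and the accompanying `Scan` method are assumptions about its shape, not the committed code:

```go
// Illustrative sketch only: field names and JSON tags are assumptions.
package model

import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
)

type ChannelInfo struct {
	IsMultiKey   bool   `json:"is_multi_key"`
	MultiKeyMode string `json:"multi_key_mode"`
}

// Value uses a value receiver, so both ChannelInfo and *ChannelInfo satisfy
// driver.Valuer and GORM can serialize either form.
func (i ChannelInfo) Value() (driver.Value, error) {
	return json.Marshal(i)
}

// Scan is the usual Scanner counterpart, decoding the stored JSON back into
// the struct.
func (i *ChannelInfo) Scan(value interface{}) error {
	switch v := value.(type) {
	case []byte:
		return json.Unmarshal(v, i)
	case string:
		return json.Unmarshal([]byte(v), i)
	default:
		return fmt.Errorf("unsupported type %T for ChannelInfo", value)
	}
}
```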

905 lines
21 KiB
Go

package controller

import (
	"encoding/json"
	"fmt"
	"net/http"
	"one-api/common"
	"one-api/constant"
	"one-api/model"
	"strconv"
	"strings"
	"github.com/gin-gonic/gin"
)

type OpenAIModel struct {
	ID         string `json:"id"`
	Object     string `json:"object"`
	Created    int64  `json:"created"`
	OwnedBy    string `json:"owned_by"`
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`
	Root   string `json:"root"`
	Parent string `json:"parent"`
}

type OpenAIModelsResponse struct {
	Data    []OpenAIModel `json:"data"`
	Success bool          `json:"success"`
}

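// parseStatusFilter maps the "status" query parameter to a filter value:
// "enabled"/"1" selects enabled channels, "disabled"/"0" selects channels in
// any non-enabled state, and anything else disables the filter (-1).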
func parseStatusFilter(statusParam string) int {
	switch strings.ToLower(statusParam) {
	case "enabled", "1":
		return common.ChannelStatusEnabled
	case "disabled", "0":
		return 0
	default:
		return -1
	}
}

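// GetAllChannels lists channels with pagination and optional status, type and
// tag-mode filtering, returning per-type counts alongside the items.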
func GetAllChannels(c *gin.Context) {
	p, _ := strconv.Atoi(c.Query("p"))
	pageSize, _ := strconv.Atoi(c.Query("page_size"))
	if p < 1 {
		p = 1
	}
	if pageSize < 1 {
		pageSize = common.ItemsPerPage
	}
	channelData := make([]*model.Channel, 0)
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
	statusParam := c.Query("status")
	// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
	statusFilter := parseStatusFilter(statusParam)
	// type filter
	typeStr := c.Query("type")
	typeFilter := -1
	if typeStr != "" {
		if t, err := strconv.Atoi(typeStr); err == nil {
			typeFilter = t
		}
	}
	var total int64
	if enableTagMode {
		tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
			return
		}
		for _, tag := range tags {
			if tag == nil || *tag == "" {
				continue
			}
			tagChannels, err := model.GetChannelsByTag(*tag, idSort)
			if err != nil {
				continue
			}
			filtered := make([]*model.Channel, 0)
			for _, ch := range tagChannels {
				if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
					continue
				}
				if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
					continue
				}
				if typeFilter >= 0 && ch.Type != typeFilter {
					continue
				}
				filtered = append(filtered, ch)
			}
			channelData = append(channelData, filtered...)
		}
		total, _ = model.CountAllTags()
	} else {
		baseQuery := model.DB.Model(&model.Channel{})
		if typeFilter >= 0 {
			baseQuery = baseQuery.Where("type = ?", typeFilter)
		}
		if statusFilter == common.ChannelStatusEnabled {
			baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
		} else if statusFilter == 0 {
			baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
		}
		baseQuery.Count(&total)
		order := "priority desc"
		if idSort {
			order = "id desc"
		}
		err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
		if err != nil {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
			return
		}
	}
	countQuery := model.DB.Model(&model.Channel{})
	if statusFilter == common.ChannelStatusEnabled {
		countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
	} else if statusFilter == 0 {
		countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
	}
	var results []struct {
		Type  int64
		Count int64
	}
	_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
	typeCounts := make(map[int64]int64)
	for _, r := range results {
		typeCounts[r.Type] = r.Count
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": gin.H{
			"items": channelData,
			"total": total,
			"page": p,
			"page_size": pageSize,
			"type_counts": typeCounts,
		},
	})
	return
}

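// FetchUpstreamModels asks the channel's upstream /v1/models endpoint (with
// provider-specific paths for Gemini and Ali) for its model list and returns
// the model IDs.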
func FetchUpstreamModels(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	channel, err := model.GetChannelById(id, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	baseURL := constant.ChannelBaseURLs[channel.Type]
	if channel.GetBaseURL() != "" {
		baseURL = channel.GetBaseURL()
	}
	url := fmt.Sprintf("%s/v1/models", baseURL)
	switch channel.Type {
	case constant.ChannelTypeGemini:
		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
	case constant.ChannelTypeAli:
		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
	}
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	var result OpenAIModelsResponse
	if err = json.Unmarshal(body, &result); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
		})
		return
	}
	var ids []string
	for _, model := range result.Data {
		id := model.ID
		if channel.Type == constant.ChannelTypeGemini {
			id = strings.TrimPrefix(id, "models/")
		}
		ids = append(ids, id)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": ids,
	})
}

func FixChannelsAbilities(c *gin.Context) {
	count, err := model.FixAbility()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": count,
	})
}

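// SearchChannels searches channels (or tags in tag_mode) by keyword, group and
// model, applies optional status/type filters, and paginates the matches in
// memory.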
func SearchChannels(c *gin.Context) {
	keyword := c.Query("keyword")
	group := c.Query("group")
	modelKeyword := c.Query("model")
	statusParam := c.Query("status")
	statusFilter := parseStatusFilter(statusParam)
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
	channelData := make([]*model.Channel, 0)
	if enableTagMode {
		tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		for _, tag := range tags {
			if tag != nil && *tag != "" {
				tagChannel, err := model.GetChannelsByTag(*tag, idSort)
				if err == nil {
					channelData = append(channelData, tagChannel...)
				}
			}
		}
	} else {
		channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		channelData = channels
	}
	if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
		filtered := make([]*model.Channel, 0, len(channelData))
		for _, ch := range channelData {
			if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
				continue
			}
			if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
				continue
			}
			filtered = append(filtered, ch)
		}
		channelData = filtered
	}
	// calculate type counts for search results
	typeCounts := make(map[int64]int64)
	for _, channel := range channelData {
		typeCounts[int64(channel.Type)]++
	}
	typeParam := c.Query("type")
	typeFilter := -1
	if typeParam != "" {
		if tp, err := strconv.Atoi(typeParam); err == nil {
			typeFilter = tp
		}
	}
	if typeFilter >= 0 {
		filtered := make([]*model.Channel, 0, len(channelData))
		for _, ch := range channelData {
			if ch.Type == typeFilter {
				filtered = append(filtered, ch)
			}
		}
		channelData = filtered
	}
	page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	if page < 1 {
		page = 1
	}
	if pageSize <= 0 {
		pageSize = 20
	}
	total := len(channelData)
	startIdx := (page - 1) * pageSize
	if startIdx > total {
		startIdx = total
	}
	endIdx := startIdx + pageSize
	if endIdx > total {
		endIdx = total
	}
	pagedData := channelData[startIdx:endIdx]
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": gin.H{
			"items": pagedData,
			"total": total,
			"type_counts": typeCounts,
		},
	})
	return
}

func GetChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	channel, err := model.GetChannelById(id, false)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": channel,
	})
	return
}

type AddChannelRequest struct {
	Mode         string                `json:"mode"`
	MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
	Channel      *model.Channel        `json:"channel"`
}

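// getVertexArrayKeys parses the submitted Vertex AI keys, which must be a JSON
// array, and returns one compact JSON string per element: string elements are
// trimmed and kept as-is, object elements are re-marshalled with json.Marshal,
// and empty entries are dropped.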
func getVertexArrayKeys(keys string) ([]string, error) {
	if keys == "" {
		return nil, nil
	}
	var keyArray []interface{}
	err := common.UnmarshalJson([]byte(keys), &keyArray)
	if err != nil {
		return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式例如[{key1}, {key2}...],请检查输入: %w", err)
	}
	cleanKeys := make([]string, 0, len(keyArray))
	for _, key := range keyArray {
		var keyStr string
		switch v := key.(type) {
		case string:
			keyStr = strings.TrimSpace(v)
		default:
			bytes, err := json.Marshal(v)
			if err != nil {
				return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
			}
			keyStr = string(bytes)
		}
		if keyStr != "" {
			cleanKeys = append(cleanKeys, keyStr)
		}
	}
	if len(cleanKeys) == 0 {
		return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
	}
	return cleanKeys, nil
}

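// AddChannel creates one or more channels from an AddChannelRequest. Mode
// "single" keeps the key untouched, "batch" creates one channel per key (a
// JSON array for Vertex AI, newline-separated otherwise), and
// "multi_to_single" stores all keys on a single multi-key channel.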
func AddChannel(c *gin.Context) {
	addChannelRequest := AddChannelRequest{}
	err := c.ShouldBindJSON(&addChannelRequest)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "channel cannot be empty",
		})
		return
	}
	// Validate the length of the model name
	for _, m := range addChannelRequest.Channel.GetModels() {
		if len(m) > 255 {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": fmt.Sprintf("模型名称过长: %s", m),
			})
			return
		}
	}
	if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
		if addChannelRequest.Channel.Other == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "部署地区不能为空",
			})
			return
		} else {
			regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": "部署地区必须是标准的Json格式例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
				})
				return
			}
			if regionMap["default"] == nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": "部署地区必须包含default字段",
				})
				return
			}
		}
	}
	addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
	keys := make([]string, 0)
	switch addChannelRequest.Mode {
	case "multi_to_single":
		addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
		addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
			array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
			addChannelRequest.Channel.Key = strings.Join(array, "\n")
		} else {
			cleanKeys := make([]string, 0)
			for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
				key = strings.TrimSpace(key)
				if key == "" {
					continue
				}
				cleanKeys = append(cleanKeys, key)
			}
			addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
		}
		keys = []string{addChannelRequest.Channel.Key}
	case "batch":
		if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
			// multi json
			keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
		} else {
			keys = strings.Split(addChannelRequest.Channel.Key, "\n")
		}
	case "single":
		keys = []string{addChannelRequest.Channel.Key}
	default:
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "不支持的添加模式",
		})
		return
	}
	channels := make([]model.Channel, 0, len(keys))
	for _, key := range keys {
		if key == "" {
			continue
		}
		localChannel := addChannelRequest.Channel
		localChannel.Key = key
		channels = append(channels, *localChannel)
	}
	err = model.BatchInsertChannels(channels)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func DeleteChannel(c *gin.Context) {
	id, _ := strconv.Atoi(c.Param("id"))
	channel := model.Channel{Id: id}
	err := channel.Delete()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func DeleteDisabledChannel(c *gin.Context) {
	rows, err := model.DeleteDisabledChannel()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": rows,
	})
	return
}

type ChannelTag struct {
	Tag          string  `json:"tag"`
	NewTag       *string `json:"new_tag"`
	Priority     *int64  `json:"priority"`
	Weight       *uint   `json:"weight"`
	ModelMapping *string `json:"model_mapping"`
	Models       *string `json:"models"`
	Groups       *string `json:"groups"`
}

func DisableTagChannels(c *gin.Context) {
	channelTag := ChannelTag{}
	err := c.ShouldBindJSON(&channelTag)
	if err != nil || channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	err = model.DisableChannelByTag(channelTag.Tag)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func EnableTagChannels(c *gin.Context) {
	channelTag := ChannelTag{}
	err := c.ShouldBindJSON(&channelTag)
	if err != nil || channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	err = model.EnableChannelByTag(channelTag.Tag)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

func EditTagChannels(c *gin.Context) {
	channelTag := ChannelTag{}
	err := c.ShouldBindJSON(&channelTag)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	if channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "tag不能为空",
		})
		return
	}
	err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
	return
}

type ChannelBatch struct {
	Ids []int   `json:"ids"`
	Tag *string `json:"tag"`
}

func DeleteChannelBatch(c *gin.Context) {
	channelBatch := ChannelBatch{}
	err := c.ShouldBindJSON(&channelBatch)
	if err != nil || len(channelBatch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	err = model.BatchDeleteChannels(channelBatch.Ids)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": len(channelBatch.Ids),
	})
	return
}

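// UpdateChannel validates the Vertex AI region configuration when applicable,
// persists the changes, and returns the channel with its key blanked out.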
func UpdateChannel(c *gin.Context) {
	channel := model.Channel{}
	err := c.ShouldBindJSON(&channel)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if channel.Type == constant.ChannelTypeVertexAi {
		if channel.Other == "" {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "部署地区不能为空",
			})
			return
		} else {
			regionMap, err := common.StrToMap(channel.Other)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": "部署地区必须是标准的Json格式例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
				})
				return
			}
			if regionMap["default"] == nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": "部署地区必须包含default字段",
				})
				return
			}
		}
	}
	err = channel.Update()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	channel.Key = ""
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": channel,
	})
	return
}

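// FetchModels lists models from an arbitrary upstream base URL and key (no
// existing channel required) by calling its /v1/models endpoint.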
func FetchModels(c *gin.Context) {
	var req struct {
		BaseURL string `json:"base_url"`
		Type    int    `json:"type"`
		Key     string `json:"key"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "Invalid request",
		})
		return
	}
	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = constant.ChannelBaseURLs[req.Type]
	}
	client := &http.Client{}
	url := fmt.Sprintf("%s/v1/models", baseURL)
	request, err := http.NewRequest("GET", url, nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// remove line breaks and extra spaces.
	key := strings.TrimSpace(req.Key)
	// If the key contains a line break, only take the first part.
	key = strings.Split(key, "\n")[0]
	request.Header.Set("Authorization", "Bearer "+key)
	response, err := client.Do(request)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// close the body even on a non-200 response to avoid leaking the connection
	defer response.Body.Close()
	// check status code
	if response.StatusCode != http.StatusOK {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": "Failed to fetch models",
		})
		return
	}
	var result struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	var models []string
	for _, model := range result.Data {
		models = append(models, model.ID)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": models,
	})
}

func BatchSetChannelTag(c *gin.Context) {
	channelBatch := ChannelBatch{}
	err := c.ShouldBindJSON(&channelBatch)
	if err != nil || len(channelBatch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": len(channelBatch.Ids),
	})
	return
}

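// GetTagModels returns the longest comma-separated models list among all
// channels that share the given tag.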
func GetTagModels(c *gin.Context) {
	tag := c.Query("tag")
	if tag == "" {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "tag不能为空",
		})
		return
	}
	channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	var longestModels string
	maxLength := 0
	// Find the longest models string among all channels with the given tag
	for _, channel := range channels {
		if channel.Models != "" {
			currentModels := strings.Split(channel.Models, ",")
			if len(currentModels) > maxLength {
				maxLength = len(currentModels)
				longestModels = channel.Models
			}
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": longestModels,
	})
	return
}