package controller import ( "encoding/json" "fmt" "net/http" "one-api/common" "one-api/model" "strconv" "strings" "github.com/gin-gonic/gin" ) type OpenAIModel struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` OwnedBy string `json:"owned_by"` Permission []struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` AllowCreateEngine bool `json:"allow_create_engine"` AllowSampling bool `json:"allow_sampling"` AllowLogprobs bool `json:"allow_logprobs"` AllowSearchIndices bool `json:"allow_search_indices"` AllowView bool `json:"allow_view"` AllowFineTuning bool `json:"allow_fine_tuning"` Organization string `json:"organization"` Group string `json:"group"` IsBlocking bool `json:"is_blocking"` } `json:"permission"` Root string `json:"root"` Parent string `json:"parent"` } type OpenAIModelsResponse struct { Data []OpenAIModel `json:"data"` Success bool `json:"success"` } func GetAllChannels(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) pageSize, _ := strconv.Atoi(c.Query("page_size")) if p < 0 { p = 0 } if pageSize < 0 { pageSize = common.ItemsPerPage } channelData := make([]*model.Channel, 0) idSort, _ := strconv.ParseBool(c.Query("id_sort")) enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) if enableTagMode { tags, err := model.GetPaginatedTags(p*pageSize, pageSize) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } for _, tag := range tags { if tag != nil && *tag != "" { tagChannel, err := model.GetChannelsByTag(*tag, idSort) if err == nil { channelData = append(channelData, tagChannel...) } } } } else { channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } channelData = channels } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": channelData, }) return } func FetchUpstreamModels(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } channel, err := model.GetChannelById(id, true) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } //if channel.Type != common.ChannelTypeOpenAI { // c.JSON(http.StatusOK, gin.H{ // "success": false, // "message": "仅支持 OpenAI 类型渠道", // }) // return //} baseURL := common.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } url := fmt.Sprintf("%s/v1/models", baseURL) body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } var result OpenAIModelsResponse if err = json.Unmarshal(body, &result); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": fmt.Sprintf("解析响应失败: %s", err.Error()), }) return } var ids []string for _, model := range result.Data { ids = append(ids, model.ID) } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": ids, }) } func FixChannelsAbilities(c *gin.Context) { count, err := model.FixAbility() if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": count, }) } func SearchChannels(c *gin.Context) { keyword := c.Query("keyword") group := c.Query("group") modelKeyword := c.Query("model") idSort, _ := strconv.ParseBool(c.Query("id_sort")) enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) channelData := make([]*model.Channel, 0) if enableTagMode { tags, err := model.SearchTags(keyword, group, modelKeyword, idSort) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } for _, tag := range tags { if tag != nil && *tag != "" { tagChannel, err := model.GetChannelsByTag(*tag, idSort) if err == nil { channelData = append(channelData, tagChannel...) } } } } else { channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } channelData = channels } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": channelData, }) return } func GetChannel(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } channel, err := model.GetChannelById(id, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": channel, }) return } func AddChannel(c *gin.Context) { channel := model.Channel{} err := c.ShouldBindJSON(&channel) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } channel.CreatedTime = common.GetTimestamp() keys := strings.Split(channel.Key, "\n") if channel.Type == common.ChannelTypeVertexAi { if channel.Other == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "部署地区不能为空", }) return } else { if common.IsJsonStr(channel.Other) { // must have default regionMap := common.StrToMap(channel.Other) if regionMap["default"] == nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "部署地区必须包含default字段", }) return } } } keys = []string{channel.Key} } channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { if key == "" { continue } localChannel := channel localChannel.Key = key channels = append(channels, localChannel) } err = model.BatchInsertChannels(channels) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } func DeleteChannel(c *gin.Context) { id, _ := strconv.Atoi(c.Param("id")) channel := model.Channel{Id: id} err := channel.Delete() if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } func DeleteDisabledChannel(c *gin.Context) { rows, err := model.DeleteDisabledChannel() if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": rows, }) return } type ChannelTag struct { Tag string `json:"tag"` NewTag *string `json:"new_tag"` Priority *int64 `json:"priority"` Weight *uint `json:"weight"` ModelMapping *string `json:"model_mapping"` Models *string `json:"models"` Groups *string `json:"groups"` } func DisableTagChannels(c *gin.Context) { channelTag := ChannelTag{} err := c.ShouldBindJSON(&channelTag) if err != nil || channelTag.Tag == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数错误", }) return } err = model.DisableChannelByTag(channelTag.Tag) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } func EnableTagChannels(c *gin.Context) { channelTag := ChannelTag{} err := c.ShouldBindJSON(&channelTag) if err != nil || channelTag.Tag == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数错误", }) return } err = model.EnableChannelByTag(channelTag.Tag) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } func EditTagChannels(c *gin.Context) { channelTag := ChannelTag{} err := c.ShouldBindJSON(&channelTag) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数错误", }) return } if channelTag.Tag == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "tag不能为空", }) return } err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", }) return } type ChannelBatch struct { Ids []int `json:"ids"` } func DeleteChannelBatch(c *gin.Context) { channelBatch := ChannelBatch{} err := c.ShouldBindJSON(&channelBatch) if err != nil || len(channelBatch.Ids) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "参数错误", }) return } err = model.BatchDeleteChannels(channelBatch.Ids) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": len(channelBatch.Ids), }) return } func UpdateChannel(c *gin.Context) { channel := model.Channel{} err := c.ShouldBindJSON(&channel) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } if channel.Type == common.ChannelTypeVertexAi { if channel.Other == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "部署地区不能为空", }) return } else { if common.IsJsonStr(channel.Other) { // must have default regionMap := common.StrToMap(channel.Other) if regionMap["default"] == nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "部署地区必须包含default字段", }) return } } } } err = channel.Update() if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": channel, }) return } func FetchModels(c *gin.Context) { var req struct { BaseURL string `json:"base_url"` Key string `json:"key"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "success": false, "message": "Invalid request", }) return } baseURL := req.BaseURL if baseURL == "" { baseURL = "https://api.openai.com" } client := &http.Client{} url := fmt.Sprintf("%s/v1/models", baseURL) request, err := http.NewRequest("GET", url, nil) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": err.Error(), }) return } request.Header.Set("Authorization", "Bearer "+req.Key) response, err := client.Do(request) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": err.Error(), }) return } //check status code if response.StatusCode != http.StatusOK { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": "Failed to fetch models", }) return } defer response.Body.Close() var result struct { Data []struct { ID string `json:"id"` } `json:"data"` } if err := json.NewDecoder(response.Body).Decode(&result); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "success": false, "message": err.Error(), }) return } var models []string for _, model := range result.Data { models = append(models, model.ID) } c.JSON(http.StatusOK, gin.H{ "success": true, "data": models, }) }