diff --git a/controller/channel.go b/controller/channel.go
index d272d7b3..d5c21db1 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -97,6 +97,7 @@ func FetchUpstreamModels(c *gin.Context) {
 		})
 		return
 	}
+
 	channel, err := model.GetChannelById(id, true)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
@@ -105,34 +106,35 @@ func FetchUpstreamModels(c *gin.Context) {
 		})
 		return
 	}
-	if channel.Type != common.ChannelTypeOpenAI {
-		c.JSON(http.StatusOK, gin.H{
-			"success": false,
-			"message": "仅支持 OpenAI 类型渠道",
-		})
-		return
+
+	//if channel.Type != common.ChannelTypeOpenAI {
+	//	c.JSON(http.StatusOK, gin.H{
+	//		"success": false,
+	//		"message": "仅支持 OpenAI 类型渠道",
+	//	})
+	//	return
+	//}
+	baseURL := common.ChannelBaseURLs[channel.Type]
+	if channel.GetBaseURL() == "" {
+		channel.BaseURL = &baseURL
 	}
-	url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
+	url := fmt.Sprintf("%s/v1/models", baseURL)
 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),
 		})
+		return
 	}
-	result := OpenAIModelsResponse{}
-	err = json.Unmarshal(body, &result)
-	if err != nil {
+
+	var result OpenAIModelsResponse
+	if err = json.Unmarshal(body, &result); err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
-			"message": err.Error(),
-		})
-	}
-	if !result.Success {
-		c.JSON(http.StatusOK, gin.H{
-			"success": false,
-			"message": "上游返回错误",
+			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
 		})
+		return
 	}
 	var ids []string
@@ -492,3 +494,79 @@ func UpdateChannel(c *gin.Context) {
 	})
 	return
 }
+
+func FetchModels(c *gin.Context) {
+	var req struct {
+		BaseURL string `json:"base_url"`
+		Key     string `json:"key"`
+	}
+
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"success": false,
+			"message": "Invalid request",
+		})
+		return
+	}
+
+	baseURL := req.BaseURL
+	if baseURL == "" {
+		baseURL = "https://api.openai.com"
+	}
+
+	client := &http.Client{}
+	url := fmt.Sprintf("%s/v1/models", baseURL)
+
+	request, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	request.Header.Set("Authorization", "Bearer "+req.Key)
+
+	response, err := client.Do(request)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	defer response.Body.Close()
+	// check status code
+	if response.StatusCode != http.StatusOK {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": "Failed to fetch models",
+		})
+		return
+	}
+
+	var result struct {
+		Data []struct {
+			ID string `json:"id"`
+		} `json:"data"`
+	}
+
+	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	var models []string
+	for _, model := range result.Data {
+		models = append(models, model.ID)
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"data":    models,
+	})
+}
diff --git a/router/api-router.go b/router/api-router.go
index 81a1341b..a64bcf52 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -98,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
 		channelRoute.POST("/batch", controller.DeleteChannelBatch)
 		channelRoute.POST("/fix", controller.FixChannelsAbilities)
 		channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
+		channelRoute.POST("/fetch_models", controller.FetchModels)
 	}
 
 	tokenRoute := apiRouter.Group("/token")
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 3e387a26..457dff42 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -193,14 +193,16 @@ const EditChannel = (props) => {
   const fetchUpstreamModelList = async (name) => {
-    if (inputs['type'] !== 1) {
-      showError(t('仅支持 OpenAI 接口格式'));
-      return;
-    }
+    // if (inputs['type'] !== 1) {
+    //   showError(t('仅支持 OpenAI 接口格式'));
+    //   return;
+    // }
     setLoading(true);
     const models = inputs['models'] || [];
     let err = false;
+
     if (isEdit) {
+      // 如果是编辑模式,使用已有的channel id获取模型列表
       const res = await API.get('/api/channel/fetch_models/' + channelId);
       if (res.data && res.data?.success) {
         models.push(...res.data.data);
@@ -208,30 +210,29 @@ const EditChannel = (props) => {
         err = true;
       }
     } else {
+      // 如果是新建模式,通过后端代理获取模型列表
       if (!inputs?.['key']) {
         showError(t('请填写密钥'));
         err = true;
       } else {
         try {
-          const host = new URL((inputs['base_url'] || 'https://api.openai.com'));
-
-          const url = `https://${host.hostname}/v1/models`;
-          const key = inputs['key'];
-          const res = await axios.get(url, {
-            headers: {
-              'Authorization': `Bearer ${key}`
-            }
+          const res = await API.post('/api/channel/fetch_models', {
+            base_url: inputs['base_url'],
+            key: inputs['key']
           });
-          if (res.data) {
-            models.push(...res.data.data.map((model) => model.id));
+
+          if (res.data && res.data.success) {
+            models.push(...res.data.data);
           } else {
             err = true;
           }
         } catch (error) {
+          console.error('Error fetching models:', error);
           err = true;
         }
       }
     }
+
     if (!err) {
       handleInputChange(name, Array.from(new Set(models)));
       showSuccess(t('获取模型列表成功'));
@@ -638,7 +639,7 @@ const EditChannel = (props) => {
                 {inputs.type === 21 && (
                   <>
-                      知识库 ID:
+                      知识库 ID:
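
Not part of the diff above: a minimal httptest sketch showing how the new FetchModels proxy handler could be exercised end to end, since the frontend now posts base_url/key to the backend instead of calling the upstream /v1/models directly from the browser (which keeps the key server-side and avoids CORS). The test file placement in package controller, the route path, the stub upstream, and the expected payload shape are assumptions for illustration, not code from this PR.

package controller

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
)

func TestFetchModelsProxiesUpstream(t *testing.T) {
	// Fake upstream that mimics the OpenAI-style /v1/models response shape (assumption).
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/models" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"data":[{"id":"gpt-4o"},{"id":"gpt-4o-mini"}]}`))
	}))
	defer upstream.Close()

	gin.SetMode(gin.TestMode)
	router := gin.New()
	// Register the handler directly; the real route sits behind auth middleware.
	router.POST("/api/channel/fetch_models", FetchModels)

	body := strings.NewReader(`{"base_url":"` + upstream.URL + `","key":"sk-test"}`)
	req := httptest.NewRequest(http.MethodPost, "/api/channel/fetch_models", body)
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rec.Code)
	}
	var resp struct {
		Success bool     `json:"success"`
		Data    []string `json:"data"`
	}
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}
	if !resp.Success || len(resp.Data) != 2 {
		t.Fatalf("unexpected response: %+v", resp)
	}
}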