feat: Add FetchModels endpoint and refactor FetchUpstreamModels
- Introduced a new `FetchModels` endpoint to retrieve model IDs from a specified base URL and API key, enhancing flexibility for different channel types. - Refactored `FetchUpstreamModels` to simplify base URL handling and improve error messages during response parsing. - Updated API routes to include the new endpoint and adjusted the frontend to utilize the new fetch mechanism for model lists. - Removed outdated checks for channel type in the frontend, streamlining the model fetching process.
This commit is contained in:
@@ -97,6 +97,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -105,34 +106,35 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type != common.ChannelTypeOpenAI {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "仅支持 OpenAI 类型渠道",
|
||||
})
|
||||
return
|
||||
|
||||
//if channel.Type != common.ChannelTypeOpenAI {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": "仅支持 OpenAI 类型渠道",
|
||||
// })
|
||||
// return
|
||||
//}
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
channel.BaseURL = &baseURL
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
result := OpenAIModelsResponse{}
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
|
||||
var result OpenAIModelsResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
if !result.Success {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "上游返回错误",
|
||||
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var ids []string
|
||||
@@ -492,3 +494,79 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func FetchModels(c *gin.Context) {
|
||||
var req struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
request.Header.Set("Authorization", "Bearer "+req.Key)
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
//check status code
|
||||
if response.StatusCode != http.StatusOK {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "Failed to fetch models",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var models []string
|
||||
for _, model := range result.Data {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -98,6 +98,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.POST("/batch", controller.DeleteChannelBatch)
|
||||
channelRoute.POST("/fix", controller.FixChannelsAbilities)
|
||||
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
|
||||
channelRoute.POST("/fetch_models", controller.FetchModels)
|
||||
|
||||
}
|
||||
tokenRoute := apiRouter.Group("/token")
|
||||
|
||||
@@ -193,14 +193,16 @@ const EditChannel = (props) => {
|
||||
|
||||
|
||||
const fetchUpstreamModelList = async (name) => {
|
||||
if (inputs['type'] !== 1) {
|
||||
showError(t('仅支持 OpenAI 接口格式'));
|
||||
return;
|
||||
}
|
||||
// if (inputs['type'] !== 1) {
|
||||
// showError(t('仅支持 OpenAI 接口格式'));
|
||||
// return;
|
||||
// }
|
||||
setLoading(true);
|
||||
const models = inputs['models'] || [];
|
||||
let err = false;
|
||||
|
||||
if (isEdit) {
|
||||
// 如果是编辑模式,使用已有的channel id获取模型列表
|
||||
const res = await API.get('/api/channel/fetch_models/' + channelId);
|
||||
if (res.data && res.data?.success) {
|
||||
models.push(...res.data.data);
|
||||
@@ -208,30 +210,29 @@ const EditChannel = (props) => {
|
||||
err = true;
|
||||
}
|
||||
} else {
|
||||
// 如果是新建模式,通过后端代理获取模型列表
|
||||
if (!inputs?.['key']) {
|
||||
showError(t('请填写密钥'));
|
||||
err = true;
|
||||
} else {
|
||||
try {
|
||||
const host = new URL((inputs['base_url'] || 'https://api.openai.com'));
|
||||
|
||||
const url = `https://${host.hostname}/v1/models`;
|
||||
const key = inputs['key'];
|
||||
const res = await axios.get(url, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${key}`
|
||||
}
|
||||
const res = await API.post('/api/channel/fetch_models', {
|
||||
base_url: inputs['base_url'],
|
||||
key: inputs['key']
|
||||
});
|
||||
if (res.data) {
|
||||
models.push(...res.data.data.map((model) => model.id));
|
||||
|
||||
if (res.data && res.data.success) {
|
||||
models.push(...res.data.data);
|
||||
} else {
|
||||
err = true;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching models:', error);
|
||||
err = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!err) {
|
||||
handleInputChange(name, Array.from(new Set(models)));
|
||||
showSuccess(t('获取模型列表成功'));
|
||||
@@ -638,7 +639,7 @@ const EditChannel = (props) => {
|
||||
{inputs.type === 21 && (
|
||||
<>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>知识库 ID:</Typography.Text>
|
||||
<Typography.Text strong><EFBFBD><EFBFBD>识库 ID:</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
label="知识库 ID"
|
||||
|
||||
Reference in New Issue
Block a user