diff --git a/controller/channel.go b/controller/channel.go index 513e3024..d9e4d422 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -36,11 +36,30 @@ type OpenAIModel struct { Parent string `json:"parent"` } +type GoogleOpenAICompatibleModels []struct { + Name string `json:"name"` + Version string `json:"version"` + DisplayName string `json:"displayName"` + Description string `json:"description,omitempty"` + InputTokenLimit int `json:"inputTokenLimit"` + OutputTokenLimit int `json:"outputTokenLimit"` + SupportedGenerationMethods []string `json:"supportedGenerationMethods"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` + MaxTemperature int `json:"maxTemperature,omitempty"` +} + type OpenAIModelsResponse struct { Data []OpenAIModel `json:"data"` Success bool `json:"success"` } +type GoogleOpenAICompatibleResponse struct { + Models []GoogleOpenAICompatibleModels `json:"models"` + NextPageToken string `json:"nextPageToken"` +} + func parseStatusFilter(statusParam string) int { switch strings.ToLower(statusParam) { case "enabled", "1": @@ -168,26 +187,59 @@ func FetchUpstreamModels(c *gin.Context) { if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } - url := fmt.Sprintf("%s/v1/models", baseURL) + + var url string switch channel.Type { case constant.ChannelTypeGemini: - url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) + // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY + url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key) case constant.ChannelTypeAli: url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) + default: + url = fmt.Sprintf("%s/v1/models", baseURL) + } + + // 获取响应体 - 根据渠道类型决定是否添加 AuthHeader + var body []byte + if channel.Type == constant.ChannelTypeGemini { + body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader + } else { + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) } - body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { common.ApiError(c, err) return } var result OpenAIModelsResponse - if err = json.Unmarshal(body, &result); err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": fmt.Sprintf("解析响应失败: %s", err.Error()), - }) - return + var parseSuccess bool + + // 适配特殊格式 + switch channel.Type { + case constant.ChannelTypeGemini: + var googleResult GoogleOpenAICompatibleResponse + if err = json.Unmarshal(body, &googleResult); err == nil { + // 转换Google格式到OpenAI格式 + for _, model := range googleResult.Models { + for _, gModel := range model { + result.Data = append(result.Data, OpenAIModel{ + ID: gModel.Name, + }) + } + } + parseSuccess = true + } + } + + // 如果解析失败,尝试OpenAI格式 + if !parseSuccess { + if err = json.Unmarshal(body, &result); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("解析响应失败: %s", err.Error()), + }) + return + } } var ids []string