Merge pull request #1547 from seefs001/feature/model_list

 feat: Enhance model listing and retrieval with support for Anthropic and Gemini models; refactor routes for better API key handling
Authored by Seefs on 2025-08-10 22:57:20 +08:00, committed by GitHub
4 changed files with 120 additions and 18 deletions
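In short, GET /v1/models now dispatches on the caller's credentials: x-api-key plus anthropic-version selects the Anthropic-flavoured handling, x-goog-api-key or a ?key= query selects the Gemini-flavoured handling, and anything else falls back to the OpenAI format; dedicated /v1beta/models and /v1beta/openai/models routes are registered as well. A minimal standalone sketch of that dispatch decision (detectChannel and the string labels below are illustrative, not identifiers from this commit, which uses the integer constants in one-api/constant):

package main

import (
    "fmt"
    "net/http"
)

// Illustrative channel labels; the real code uses constants from one-api/constant.
const (
    channelOpenAI    = "openai"
    channelAnthropic = "anthropic"
    channelGemini    = "gemini"
)

// detectChannel is a hypothetical helper mirroring the switch added to GET /v1/models.
func detectChannel(r *http.Request) string {
    switch {
    case r.Header.Get("x-api-key") != "" && r.Header.Get("anthropic-version") != "":
        return channelAnthropic
    case r.Header.Get("x-goog-api-key") != "" || r.URL.Query().Get("key") != "":
        return channelGemini
    default:
        return channelOpenAI
    }
}

func main() {
    // Example request shaped like an Anthropic client's call (host/port are placeholders).
    req, _ := http.NewRequest(http.MethodGet, "http://localhost:3000/v1/models", nil)
    req.Header.Set("x-api-key", "sk-example")
    req.Header.Set("anthropic-version", "2023-06-01")
    fmt.Println(detectChannel(req)) // anthropic
}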

View File

@@ -16,6 +16,7 @@ import (
 	"one-api/relay/channel/moonshot"
 	relaycommon "one-api/relay/common"
 	"one-api/setting"
+	"time"
 )
 
 // https://platform.openai.com/docs/api-reference/models/list
@@ -102,7 +103,7 @@ func init() {
 	})
 }
 
-func ListModels(c *gin.Context) {
+func ListModels(c *gin.Context, modelType int) {
 	userOpenAiModels := make([]dto.OpenAIModels, 0)
 	modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
@@ -171,10 +172,41 @@ func ListModels(c *gin.Context) {
 			}
 		}
 	}
-	c.JSON(200, gin.H{
-		"success": true,
-		"data":    userOpenAiModels,
-	})
+	switch modelType {
+	case constant.ChannelTypeAnthropic:
+		useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
+		for i, model := range userOpenAiModels {
+			useranthropicModels[i] = dto.AnthropicModel{
+				ID:          model.Id,
+				CreatedAt:   time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
+				DisplayName: model.Id,
+				Type:        "model",
+			}
+		}
+		c.JSON(200, gin.H{
+			"data":     useranthropicModels,
+			"first_id": useranthropicModels[0].ID,
+			"has_more": false,
+			"last_id":  useranthropicModels[len(useranthropicModels)-1].ID,
+		})
+	case constant.ChannelTypeGemini:
+		userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
+		for i, model := range userOpenAiModels {
+			userGeminiModels[i] = dto.GeminiModel{
+				Name:        model.Id,
+				DisplayName: model.Id,
+			}
+		}
+		c.JSON(200, gin.H{
+			"models":        userGeminiModels,
+			"nextPageToken": nil,
+		})
+	default:
+		c.JSON(200, gin.H{
+			"success": true,
+			"data":    userOpenAiModels,
+		})
+	}
 }
 
 func ChannelListModels(c *gin.Context) {
@@ -198,10 +230,20 @@ func EnabledListModels(c *gin.Context) {
 	})
 }
 
-func RetrieveModel(c *gin.Context) {
+func RetrieveModel(c *gin.Context, modelType int) {
 	modelId := c.Param("model")
 	if aiModel, ok := openAIModelsMap[modelId]; ok {
-		c.JSON(200, aiModel)
+		switch modelType {
+		case constant.ChannelTypeAnthropic:
+			c.JSON(200, dto.AnthropicModel{
+				ID:          aiModel.Id,
+				CreatedAt:   time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
+				DisplayName: aiModel.Id,
+				Type:        "model",
+			})
+		default:
+			c.JSON(200, aiModel)
+		}
 	} else {
 		openAIError := dto.OpenAIError{
 			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
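For context, the Anthropic branch above mirrors Anthropic's list-models envelope (data, first_id, last_id, has_more) and the Gemini branch mirrors the models/nextPageToken shape, while the default case keeps the existing success/data payload. A rough sketch of the JSON the Anthropic branch emits, using a local copy of the struct so it runs on its own (the model id and timestamp are just examples):

package main

import (
    "encoding/json"
    "fmt"
    "time"
)

// Local copy of dto.AnthropicModel so the sketch is self-contained.
type AnthropicModel struct {
    ID          string `json:"id"`
    CreatedAt   string `json:"created_at"`
    DisplayName string `json:"display_name"`
    Type        string `json:"type"`
}

func main() {
    // One entry, built the same way the Anthropic branch of ListModels builds it.
    m := AnthropicModel{
        ID:          "gpt-4o", // whichever model ids the token is allowed to use
        CreatedAt:   time.Unix(1626777600, 0).UTC().Format(time.RFC3339),
        DisplayName: "gpt-4o",
        Type:        "model",
    }
    body, _ := json.MarshalIndent(map[string]interface{}{
        "data":     []AnthropicModel{m},
        "first_id": m.ID,
        "has_more": false,
        "last_id":  m.ID,
    }, "", "  ")
    fmt.Println(string(body))
}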

View File

@@ -2,6 +2,7 @@ package dto
 
 import "one-api/constant"
 
+// Leaving this struct as-is since it is awkward to move; it was originally meant to be split out (
 type OpenAIModels struct {
 	Id     string `json:"id"`
 	Object string `json:"object"`
@@ -9,3 +10,26 @@ type OpenAIModels struct {
 	OwnedBy                string                  `json:"owned_by"`
 	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
 }
+
+type AnthropicModel struct {
+	ID          string `json:"id"`
+	CreatedAt   string `json:"created_at"`
+	DisplayName string `json:"display_name"`
+	Type        string `json:"type"`
+}
+
+type GeminiModel struct {
+	Name                       interface{}   `json:"name"`
+	BaseModelId                interface{}   `json:"baseModelId"`
+	Version                    interface{}   `json:"version"`
+	DisplayName                interface{}   `json:"displayName"`
+	Description                interface{}   `json:"description"`
+	InputTokenLimit            interface{}   `json:"inputTokenLimit"`
+	OutputTokenLimit           interface{}   `json:"outputTokenLimit"`
+	SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
+	Thinking                   interface{}   `json:"thinking"`
+	Temperature                interface{}   `json:"temperature"`
+	MaxTemperature             interface{}   `json:"maxTemperature"`
+	TopP                       interface{}   `json:"topP"`
+	TopK                       interface{}   `json:"topK"`
+}
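One consequence of declaring every GeminiModel field as interface{} (or []interface{}) without omitempty is that fields the controller leaves unset are serialized as JSON null rather than omitted. A small standalone sketch of that behaviour, using an abridged copy of the struct:

package main

import (
    "encoding/json"
    "fmt"
)

// Abridged copy of dto.GeminiModel; the full struct behaves the same on every field.
type GeminiModel struct {
    Name                       interface{}   `json:"name"`
    DisplayName                interface{}   `json:"displayName"`
    InputTokenLimit            interface{}   `json:"inputTokenLimit"`
    SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
}

func main() {
    // ListModels only fills Name and DisplayName, so everything else comes out as null.
    m := GeminiModel{Name: "gemini-1.5-pro", DisplayName: "gemini-1.5-pro"}
    out, _ := json.Marshal(m)
    fmt.Println(string(out))
    // {"name":"gemini-1.5-pro","displayName":"gemini-1.5-pro","inputTokenLimit":null,"supportedGenerationMethods":null}
}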

View File

@@ -192,16 +192,18 @@ func TokenAuth() func(c *gin.Context) {
 		}
 		c.Request.Header.Set("Authorization", "Bearer "+key)
 	}
-	// Check whether the path contains /v1/messages
-	if strings.Contains(c.Request.URL.Path, "/v1/messages") {
-		// Get the key from x-api-key
-		key := c.Request.Header.Get("x-api-key")
-		if key != "" {
-			c.Request.Header.Set("Authorization", "Bearer "+key)
-		}
-	}
+	anthropicKey := c.Request.Header.Get("x-api-key")
+	// Check whether the path contains /v1/messages,
+	// or whether x-api-key is non-empty and anthropic-version is present.
+	// Who knows how many non-conforming clients never send anthropic-version,
+	// so just let it be (
+	if strings.Contains(c.Request.URL.Path, "/v1/messages") || (anthropicKey != "" && c.Request.Header.Get("anthropic-version") != "") {
+		c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
+	}
 	// gemini api: get the key from the query string
-	if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
+	if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") ||
+		strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") ||
+		strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
 		skKey := c.Query("key")
 		if skKey != "" {
 			c.Request.Header.Set("Authorization", "Bearer "+skKey)
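The net effect of this hunk is that an Anthropic-style key arriving in x-api-key, or a Gemini-style key arriving as ?key= on the model-listing paths, is copied into the standard Authorization: Bearer header before the token check runs. A simplified standalone sketch of that normalization outside the full TokenAuth middleware (the helper name is illustrative, and unlike the real code it skips the empty-key case):

package main

import (
    "fmt"
    "net/http"
    "strings"
)

// normalizeAuthHeader is an illustrative stand-in for the header handling added to
// middleware.TokenAuth: it copies an Anthropic x-api-key or a Gemini ?key= value
// into the Authorization header so the rest of the token check sees one format.
func normalizeAuthHeader(r *http.Request) {
    anthropicKey := r.Header.Get("x-api-key")
    // Anthropic: /v1/messages, or any request that sends x-api-key with anthropic-version.
    if anthropicKey != "" &&
        (strings.Contains(r.URL.Path, "/v1/messages") || r.Header.Get("anthropic-version") != "") {
        r.Header.Set("Authorization", "Bearer "+anthropicKey)
    }
    // Gemini: the key arrives in the query string on the model-listing paths.
    if strings.HasPrefix(r.URL.Path, "/v1beta/models") ||
        strings.HasPrefix(r.URL.Path, "/v1beta/openai/models") ||
        strings.HasPrefix(r.URL.Path, "/v1/models/") {
        if key := r.URL.Query().Get("key"); key != "" {
            r.Header.Set("Authorization", "Bearer "+key)
        }
    }
}

func main() {
    req, _ := http.NewRequest(http.MethodGet, "http://localhost:3000/v1beta/models?key=AIza-example", nil)
    normalizeAuthHeader(req)
    fmt.Println(req.Header.Get("Authorization")) // Bearer AIza-example
}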

View File

@@ -1,11 +1,11 @@
 package router
 
 import (
+	"github.com/gin-gonic/gin"
+	"one-api/constant"
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/relay"
-	"github.com/gin-gonic/gin"
 )
 
 func SetRelayRouter(router *gin.Engine) {
@@ -16,9 +16,43 @@ func SetRelayRouter(router *gin.Engine) {
 	modelsRouter := router.Group("/v1/models")
 	modelsRouter.Use(middleware.TokenAuth())
 	{
-		modelsRouter.GET("", controller.ListModels)
-		modelsRouter.GET("/:model", controller.RetrieveModel)
+		modelsRouter.GET("", func(c *gin.Context) {
+			switch {
+			case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
+				controller.ListModels(c, constant.ChannelTypeAnthropic)
+			case c.GetHeader("x-goog-api-key") != "" || c.Query("key") != "": // handled separately
+				controller.RetrieveModel(c, constant.ChannelTypeGemini)
+			default:
+				controller.ListModels(c, constant.ChannelTypeOpenAI)
+			}
+		})
+		modelsRouter.GET("/:model", func(c *gin.Context) {
+			switch {
+			case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
+				controller.RetrieveModel(c, constant.ChannelTypeAnthropic)
+			default:
+				controller.RetrieveModel(c, constant.ChannelTypeOpenAI)
+			}
+		})
 	}
+	geminiRouter := router.Group("/v1beta/models")
+	geminiRouter.Use(middleware.TokenAuth())
+	{
+		geminiRouter.GET("", func(c *gin.Context) {
+			controller.ListModels(c, constant.ChannelTypeGemini)
+		})
+	}
+	geminiCompatibleRouter := router.Group("/v1beta/openai/models")
+	geminiCompatibleRouter.Use(middleware.TokenAuth())
+	{
+		geminiCompatibleRouter.GET("", func(c *gin.Context) {
+			controller.ListModels(c, constant.ChannelTypeOpenAI)
+		})
+	}
 	playgroundRouter := router.Group("/pg")
 	playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
 	{