diff --git a/controller/model.go b/controller/model.go index 31a66b29..d03fdeb2 100644 --- a/controller/model.go +++ b/controller/model.go @@ -16,6 +16,7 @@ import ( "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" "one-api/setting" + "time" ) // https://platform.openai.com/docs/api-reference/models/list @@ -102,7 +103,7 @@ func init() { }) } -func ListModels(c *gin.Context) { +func ListModels(c *gin.Context, modelType int) { userOpenAiModels := make([]dto.OpenAIModels, 0) modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) @@ -171,10 +172,41 @@ func ListModels(c *gin.Context) { } } } - c.JSON(200, gin.H{ - "success": true, - "data": userOpenAiModels, - }) + switch modelType { + case constant.ChannelTypeAnthropic: + useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels)) + for i, model := range userOpenAiModels { + useranthropicModels[i] = dto.AnthropicModel{ + ID: model.Id, + CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339), + DisplayName: model.Id, + Type: "model", + } + } + c.JSON(200, gin.H{ + "data": useranthropicModels, + "first_id": useranthropicModels[0].ID, + "has_more": false, + "last_id": useranthropicModels[len(useranthropicModels)-1].ID, + }) + case constant.ChannelTypeGemini: + userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels)) + for i, model := range userOpenAiModels { + userGeminiModels[i] = dto.GeminiModel{ + Name: model.Id, + DisplayName: model.Id, + } + } + c.JSON(200, gin.H{ + "models": userGeminiModels, + "nextPageToken": nil, + }) + default: + c.JSON(200, gin.H{ + "success": true, + "data": userOpenAiModels, + }) + } } func ChannelListModels(c *gin.Context) { @@ -198,10 +230,20 @@ func EnabledListModels(c *gin.Context) { }) } -func RetrieveModel(c *gin.Context) { +func RetrieveModel(c *gin.Context, modelType int) { modelId := c.Param("model") if aiModel, ok := openAIModelsMap[modelId]; ok { - c.JSON(200, aiModel) + switch modelType { + case constant.ChannelTypeAnthropic: + c.JSON(200, dto.AnthropicModel{ + ID: aiModel.Id, + CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339), + DisplayName: aiModel.Id, + Type: "model", + }) + default: + c.JSON(200, aiModel) + } } else { openAIError := dto.OpenAIError{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), diff --git a/dto/pricing.go b/dto/pricing.go index 0f317d9d..bc024de3 100644 --- a/dto/pricing.go +++ b/dto/pricing.go @@ -2,6 +2,7 @@ package dto import "one-api/constant" +// 这里不好动就不动了,本来想独立出来的( type OpenAIModels struct { Id string `json:"id"` Object string `json:"object"` @@ -9,3 +10,26 @@ type OpenAIModels struct { OwnedBy string `json:"owned_by"` SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` } + +type AnthropicModel struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` + DisplayName string `json:"display_name"` + Type string `json:"type"` +} + +type GeminiModel struct { + Name interface{} `json:"name"` + BaseModelId interface{} `json:"baseModelId"` + Version interface{} `json:"version"` + DisplayName interface{} `json:"displayName"` + Description interface{} `json:"description"` + InputTokenLimit interface{} `json:"inputTokenLimit"` + OutputTokenLimit interface{} `json:"outputTokenLimit"` + SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"` + Thinking interface{} `json:"thinking"` + Temperature interface{} `json:"temperature"` + MaxTemperature interface{} `json:"maxTemperature"` + TopP interface{} `json:"topP"` + TopK interface{} `json:"topK"` +} diff --git a/middleware/auth.go b/middleware/auth.go index 5f6e5d43..ee8d9241 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -192,16 +192,18 @@ func TokenAuth() func(c *gin.Context) { } c.Request.Header.Set("Authorization", "Bearer "+key) } + anthropicKey := c.Request.Header.Get("x-api-key") // 检查path包含/v1/messages - if strings.Contains(c.Request.URL.Path, "/v1/messages") { - // 从x-api-key中获取key - key := c.Request.Header.Get("x-api-key") - if key != "" { - c.Request.Header.Set("Authorization", "Bearer "+key) - } + // 或者是否 x-api-key 不为空且存在anthropic-version + // 谁知道有多少不符合规范没写anthropic-version的 + // 所以就这样随它去吧( + if strings.Contains(c.Request.URL.Path, "/v1/messages") || (anthropicKey != "" && c.Request.Header.Get("anthropic-version") != "") { + c.Request.Header.Set("Authorization", "Bearer "+anthropicKey) } // gemini api 从query中获取key - if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") { + if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") || + strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") || + strings.HasPrefix(c.Request.URL.Path, "/v1/models/") { skKey := c.Query("key") if skKey != "" { c.Request.Header.Set("Authorization", "Bearer "+skKey) diff --git a/router/relay-router.go b/router/relay-router.go index 5b293dbd..cd656580 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -1,11 +1,11 @@ package router import ( + "github.com/gin-gonic/gin" + "one-api/constant" "one-api/controller" "one-api/middleware" "one-api/relay" - - "github.com/gin-gonic/gin" ) func SetRelayRouter(router *gin.Engine) { @@ -16,9 +16,43 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) { - modelsRouter.GET("", controller.ListModels) - modelsRouter.GET("/:model", controller.RetrieveModel) + modelsRouter.GET("", func(c *gin.Context) { + switch { + case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "": + controller.ListModels(c, constant.ChannelTypeAnthropic) + case c.GetHeader("x-goog-api-key") != "" || c.Query("key") != "": // 单独的适配 + controller.RetrieveModel(c, constant.ChannelTypeGemini) + default: + controller.ListModels(c, constant.ChannelTypeOpenAI) + } + }) + + modelsRouter.GET("/:model", func(c *gin.Context) { + switch { + case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "": + controller.RetrieveModel(c, constant.ChannelTypeAnthropic) + default: + controller.RetrieveModel(c, constant.ChannelTypeOpenAI) + } + }) } + + geminiRouter := router.Group("/v1beta/models") + geminiRouter.Use(middleware.TokenAuth()) + { + geminiRouter.GET("", func(c *gin.Context) { + controller.ListModels(c, constant.ChannelTypeGemini) + }) + } + + geminiCompatibleRouter := router.Group("/v1beta/openai/models") + geminiCompatibleRouter.Use(middleware.TokenAuth()) + { + geminiCompatibleRouter.GET("", func(c *gin.Context) { + controller.ListModels(c, constant.ChannelTypeOpenAI) + }) + } + playgroundRouter := router.Group("/pg") playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute()) {