diff --git a/common/endpoint_defaults.go b/common/endpoint_defaults.go index 1dfe1dc9..ffc26350 100644 --- a/common/endpoint_defaults.go +++ b/common/endpoint_defaults.go @@ -11,22 +11,22 @@ import "one-api/constant" // 例如:{"path":"/v1/chat/completions","method":"POST"} type EndpointInfo struct { - Path string `json:"path"` - Method string `json:"method"` + Path string `json:"path"` + Method string `json:"method"` } // defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{ - constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"}, - constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"}, - constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"}, - constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"}, - constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"}, - constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"}, + constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"}, + constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"}, + constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"}, + constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"}, + constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"}, + constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"}, } // GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在 func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) { - info, ok := defaultEndpointInfoMap[et] - return info, ok + info, ok := defaultEndpointInfoMap[et] + return info, ok } diff --git a/controller/missing_models.go b/controller/missing_models.go index a3409e29..425f9b25 100644 --- a/controller/missing_models.go +++ b/controller/missing_models.go @@ -1,27 +1,27 @@ package controller import ( - "net/http" - "one-api/model" + "net/http" + "one-api/model" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) // GetMissingModels returns the list of model names that are referenced by channels // but do not have corresponding records in the models meta table. // This helps administrators quickly discover models that need configuration. func GetMissingModels(c *gin.Context) { - missing, err := model.GetMissingModels() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } + missing, err := model.GetMissingModels() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": missing, - }) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": missing, + }) } diff --git a/controller/model_meta.go b/controller/model_meta.go index 090ea3c1..b097c80a 100644 --- a/controller/model_meta.go +++ b/controller/model_meta.go @@ -1,178 +1,178 @@ package controller import ( - "encoding/json" - "strconv" + "encoding/json" + "strconv" - "one-api/common" - "one-api/model" + "one-api/common" + "one-api/model" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) // GetAllModelsMeta 获取模型列表(分页) func GetAllModelsMeta(c *gin.Context) { - pageInfo := common.GetPageQuery(c) - modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) - if err != nil { - common.ApiError(c, err) - return - } - // 填充附加字段 - for _, m := range modelsMeta { - fillModelExtra(m) - } - var total int64 - model.DB.Model(&model.Model{}).Count(&total) + pageInfo := common.GetPageQuery(c) + modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + // 填充附加字段 + for _, m := range modelsMeta { + fillModelExtra(m) + } + var total int64 + model.DB.Model(&model.Model{}).Count(&total) - // 统计供应商计数(全部数据,不受分页影响) - vendorCounts, _ := model.GetVendorModelCounts() + // 统计供应商计数(全部数据,不受分页影响) + vendorCounts, _ := model.GetVendorModelCounts() - pageInfo.SetTotal(int(total)) - pageInfo.SetItems(modelsMeta) - common.ApiSuccess(c, gin.H{ - "items": modelsMeta, - "total": total, - "page": pageInfo.GetPage(), - "page_size": pageInfo.GetPageSize(), - "vendor_counts": vendorCounts, - }) + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(modelsMeta) + common.ApiSuccess(c, gin.H{ + "items": modelsMeta, + "total": total, + "page": pageInfo.GetPage(), + "page_size": pageInfo.GetPageSize(), + "vendor_counts": vendorCounts, + }) } // SearchModelsMeta 搜索模型列表 func SearchModelsMeta(c *gin.Context) { - keyword := c.Query("keyword") - vendor := c.Query("vendor") - pageInfo := common.GetPageQuery(c) + keyword := c.Query("keyword") + vendor := c.Query("vendor") + pageInfo := common.GetPageQuery(c) - modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) - if err != nil { - common.ApiError(c, err) - return - } - for _, m := range modelsMeta { - fillModelExtra(m) - } - pageInfo.SetTotal(int(total)) - pageInfo.SetItems(modelsMeta) - common.ApiSuccess(c, pageInfo) + modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + for _, m := range modelsMeta { + fillModelExtra(m) + } + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(modelsMeta) + common.ApiSuccess(c, pageInfo) } // GetModelMeta 根据 ID 获取单条模型信息 func GetModelMeta(c *gin.Context) { - idStr := c.Param("id") - id, err := strconv.Atoi(idStr) - if err != nil { - common.ApiError(c, err) - return - } - var m model.Model - if err := model.DB.First(&m, id).Error; err != nil { - common.ApiError(c, err) - return - } - fillModelExtra(&m) - common.ApiSuccess(c, &m) + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + var m model.Model + if err := model.DB.First(&m, id).Error; err != nil { + common.ApiError(c, err) + return + } + fillModelExtra(&m) + common.ApiSuccess(c, &m) } // CreateModelMeta 新建模型 func CreateModelMeta(c *gin.Context) { - var m model.Model - if err := c.ShouldBindJSON(&m); err != nil { - common.ApiError(c, err) - return - } - if m.ModelName == "" { - common.ApiErrorMsg(c, "模型名称不能为空") - return - } - // 名称冲突检查 - if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil { - common.ApiError(c, err) - return - } else if dup { - common.ApiErrorMsg(c, "模型名称已存在") - return - } + var m model.Model + if err := c.ShouldBindJSON(&m); err != nil { + common.ApiError(c, err) + return + } + if m.ModelName == "" { + common.ApiErrorMsg(c, "模型名称不能为空") + return + } + // 名称冲突检查 + if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "模型名称已存在") + return + } - if err := m.Insert(); err != nil { - common.ApiError(c, err) - return - } - model.RefreshPricing() - common.ApiSuccess(c, &m) + if err := m.Insert(); err != nil { + common.ApiError(c, err) + return + } + model.RefreshPricing() + common.ApiSuccess(c, &m) } // UpdateModelMeta 更新模型 func UpdateModelMeta(c *gin.Context) { - statusOnly := c.Query("status_only") == "true" + statusOnly := c.Query("status_only") == "true" - var m model.Model - if err := c.ShouldBindJSON(&m); err != nil { - common.ApiError(c, err) - return - } - if m.Id == 0 { - common.ApiErrorMsg(c, "缺少模型 ID") - return - } + var m model.Model + if err := c.ShouldBindJSON(&m); err != nil { + common.ApiError(c, err) + return + } + if m.Id == 0 { + common.ApiErrorMsg(c, "缺少模型 ID") + return + } - if statusOnly { - // 只更新状态,防止误清空其他字段 - if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil { - common.ApiError(c, err) - return - } - } else { - // 名称冲突检查 - if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil { - common.ApiError(c, err) - return - } else if dup { - common.ApiErrorMsg(c, "模型名称已存在") - return - } + if statusOnly { + // 只更新状态,防止误清空其他字段 + if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil { + common.ApiError(c, err) + return + } + } else { + // 名称冲突检查 + if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "模型名称已存在") + return + } - if err := m.Update(); err != nil { - common.ApiError(c, err) - return - } - } - model.RefreshPricing() - common.ApiSuccess(c, &m) + if err := m.Update(); err != nil { + common.ApiError(c, err) + return + } + } + model.RefreshPricing() + common.ApiSuccess(c, &m) } // DeleteModelMeta 删除模型 func DeleteModelMeta(c *gin.Context) { - idStr := c.Param("id") - id, err := strconv.Atoi(idStr) - if err != nil { - common.ApiError(c, err) - return - } - if err := model.DB.Delete(&model.Model{}, id).Error; err != nil { - common.ApiError(c, err) - return - } - model.RefreshPricing() - common.ApiSuccess(c, nil) + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + if err := model.DB.Delete(&model.Model{}, id).Error; err != nil { + common.ApiError(c, err) + return + } + model.RefreshPricing() + common.ApiSuccess(c, nil) } // 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups func fillModelExtra(m *model.Model) { - if m.Endpoints == "" { - eps := model.GetModelSupportEndpointTypes(m.ModelName) - if b, err := json.Marshal(eps); err == nil { - m.Endpoints = string(b) - } - } - if channels, err := model.GetBoundChannels(m.ModelName); err == nil { - m.BoundChannels = channels - } - // 填充启用分组 - m.EnableGroups = model.GetModelEnableGroups(m.ModelName) - // 填充计费类型 - m.QuotaType = model.GetModelQuotaType(m.ModelName) + if m.Endpoints == "" { + eps := model.GetModelSupportEndpointTypes(m.ModelName) + if b, err := json.Marshal(eps); err == nil { + m.Endpoints = string(b) + } + } + if channels, err := model.GetBoundChannels(m.ModelName); err == nil { + m.BoundChannels = channels + } + // 填充启用分组 + m.EnableGroups = model.GetModelEnableGroups(m.ModelName) + // 填充计费类型 + m.QuotaType = model.GetModelQuotaType(m.ModelName) } diff --git a/controller/prefill_group.go b/controller/prefill_group.go index 4e29379b..d912d609 100644 --- a/controller/prefill_group.go +++ b/controller/prefill_group.go @@ -1,90 +1,90 @@ package controller import ( - "strconv" + "strconv" - "one-api/common" - "one-api/model" + "one-api/common" + "one-api/model" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) // GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤 func GetPrefillGroups(c *gin.Context) { - groupType := c.Query("type") - groups, err := model.GetAllPrefillGroups(groupType) - if err != nil { - common.ApiError(c, err) - return - } - common.ApiSuccess(c, groups) + groupType := c.Query("type") + groups, err := model.GetAllPrefillGroups(groupType) + if err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, groups) } // CreatePrefillGroup 创建新的预填组 func CreatePrefillGroup(c *gin.Context) { - var g model.PrefillGroup - if err := c.ShouldBindJSON(&g); err != nil { - common.ApiError(c, err) - return - } - if g.Name == "" || g.Type == "" { - common.ApiErrorMsg(c, "组名称和类型不能为空") - return - } - // 创建前检查名称 - if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil { - common.ApiError(c, err) - return - } else if dup { - common.ApiErrorMsg(c, "组名称已存在") - return - } + var g model.PrefillGroup + if err := c.ShouldBindJSON(&g); err != nil { + common.ApiError(c, err) + return + } + if g.Name == "" || g.Type == "" { + common.ApiErrorMsg(c, "组名称和类型不能为空") + return + } + // 创建前检查名称 + if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "组名称已存在") + return + } - if err := g.Insert(); err != nil { - common.ApiError(c, err) - return - } - common.ApiSuccess(c, &g) + if err := g.Insert(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &g) } // UpdatePrefillGroup 更新预填组 func UpdatePrefillGroup(c *gin.Context) { - var g model.PrefillGroup - if err := c.ShouldBindJSON(&g); err != nil { - common.ApiError(c, err) - return - } - if g.Id == 0 { - common.ApiErrorMsg(c, "缺少组 ID") - return - } - // 名称冲突检查 - if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil { - common.ApiError(c, err) - return - } else if dup { - common.ApiErrorMsg(c, "组名称已存在") - return - } + var g model.PrefillGroup + if err := c.ShouldBindJSON(&g); err != nil { + common.ApiError(c, err) + return + } + if g.Id == 0 { + common.ApiErrorMsg(c, "缺少组 ID") + return + } + // 名称冲突检查 + if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "组名称已存在") + return + } - if err := g.Update(); err != nil { - common.ApiError(c, err) - return - } - common.ApiSuccess(c, &g) + if err := g.Update(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &g) } // DeletePrefillGroup 删除预填组 func DeletePrefillGroup(c *gin.Context) { - idStr := c.Param("id") - id, err := strconv.Atoi(idStr) - if err != nil { - common.ApiError(c, err) - return - } - if err := model.DeletePrefillGroupByID(id); err != nil { - common.ApiError(c, err) - return - } - common.ApiSuccess(c, nil) + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + if err := model.DeletePrefillGroupByID(id); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, nil) } diff --git a/controller/pricing.go b/controller/pricing.go index 898c9f9f..4b7cc86d 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -39,14 +39,14 @@ func GetPricing(c *gin.Context) { } c.JSON(200, gin.H{ - "success": true, - "data": pricing, - "vendors": model.GetVendors(), - "group_ratio": groupRatio, - "usable_group": usableGroup, - "supported_endpoint": model.GetSupportedEndpointMap(), - "auto_groups": setting.AutoGroups, - }) + "success": true, + "data": pricing, + "vendors": model.GetVendors(), + "group_ratio": groupRatio, + "usable_group": usableGroup, + "supported_endpoint": model.GetSupportedEndpointMap(), + "auto_groups": setting.AutoGroups, + }) } func ResetModelRatio(c *gin.Context) { diff --git a/controller/user.go b/controller/user.go index 6e968037..29cf83e1 100644 --- a/controller/user.go +++ b/controller/user.go @@ -62,7 +62,7 @@ func Login(c *gin.Context) { }) return } - + // 检查是否启用2FA if model.IsTwoFAEnabled(user.Id) { // 设置pending session,等待2FA验证 @@ -77,7 +77,7 @@ func Login(c *gin.Context) { }) return } - + c.JSON(http.StatusOK, gin.H{ "message": "请输入两步验证码", "success": true, @@ -87,7 +87,7 @@ func Login(c *gin.Context) { }) return } - + setupLogin(&user, c) } diff --git a/controller/vendor_meta.go b/controller/vendor_meta.go index 28664dd6..21d5a21d 100644 --- a/controller/vendor_meta.go +++ b/controller/vendor_meta.go @@ -1,124 +1,124 @@ package controller import ( - "strconv" + "strconv" - "one-api/common" - "one-api/model" + "one-api/common" + "one-api/model" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) // GetAllVendors 获取供应商列表(分页) func GetAllVendors(c *gin.Context) { - pageInfo := common.GetPageQuery(c) - vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) - if err != nil { - common.ApiError(c, err) - return - } - var total int64 - model.DB.Model(&model.Vendor{}).Count(&total) - pageInfo.SetTotal(int(total)) - pageInfo.SetItems(vendors) - common.ApiSuccess(c, pageInfo) + pageInfo := common.GetPageQuery(c) + vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + var total int64 + model.DB.Model(&model.Vendor{}).Count(&total) + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(vendors) + common.ApiSuccess(c, pageInfo) } // SearchVendors 搜索供应商 func SearchVendors(c *gin.Context) { - keyword := c.Query("keyword") - pageInfo := common.GetPageQuery(c) - vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) - if err != nil { - common.ApiError(c, err) - return - } - pageInfo.SetTotal(int(total)) - pageInfo.SetItems(vendors) - common.ApiSuccess(c, pageInfo) + keyword := c.Query("keyword") + pageInfo := common.GetPageQuery(c) + vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(vendors) + common.ApiSuccess(c, pageInfo) } // GetVendorMeta 根据 ID 获取供应商 func GetVendorMeta(c *gin.Context) { - idStr := c.Param("id") - id, err := strconv.Atoi(idStr) - if err != nil { - common.ApiError(c, err) - return - } - v, err := model.GetVendorByID(id) - if err != nil { - common.ApiError(c, err) - return - } - common.ApiSuccess(c, v) + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + v, err := model.GetVendorByID(id) + if err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, v) } // CreateVendorMeta 新建供应商 func CreateVendorMeta(c *gin.Context) { - var v model.Vendor - if err := c.ShouldBindJSON(&v); err != nil { - common.ApiError(c, err) - return - } - if v.Name == "" { - common.ApiErrorMsg(c, "供应商名称不能为空") - return - } - // 创建前先检查名称 - if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil { - common.ApiError(c, err) - return - } else if dup { - common.ApiErrorMsg(c, "供应商名称已存在") - return - } + var v model.Vendor + if err := c.ShouldBindJSON(&v); err != nil { + common.ApiError(c, err) + return + } + if v.Name == "" { + common.ApiErrorMsg(c, "供应商名称不能为空") + return + } + // 创建前先检查名称 + if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "供应商名称已存在") + return + } - if err := v.Insert(); err != nil { - common.ApiError(c, err) - return - } - common.ApiSuccess(c, &v) + if err := v.Insert(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &v) } // UpdateVendorMeta 更新供应商 func UpdateVendorMeta(c *gin.Context) { - var v model.Vendor - if err := c.ShouldBindJSON(&v); err != nil { - common.ApiError(c, err) - return - } - if v.Id == 0 { - common.ApiErrorMsg(c, "缺少供应商 ID") - return - } - // 名称冲突检查 - if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil { - common.ApiError(c, err) - return - } else if dup { - common.ApiErrorMsg(c, "供应商名称已存在") - return - } + var v model.Vendor + if err := c.ShouldBindJSON(&v); err != nil { + common.ApiError(c, err) + return + } + if v.Id == 0 { + common.ApiErrorMsg(c, "缺少供应商 ID") + return + } + // 名称冲突检查 + if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "供应商名称已存在") + return + } - if err := v.Update(); err != nil { - common.ApiError(c, err) - return - } - common.ApiSuccess(c, &v) + if err := v.Update(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &v) } // DeleteVendorMeta 删除供应商 func DeleteVendorMeta(c *gin.Context) { - idStr := c.Param("id") - id, err := strconv.Atoi(idStr) - if err != nil { - common.ApiError(c, err) - return - } - if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil { - common.ApiError(c, err) - return - } - common.ApiSuccess(c, nil) -} \ No newline at end of file + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, nil) +} diff --git a/model/missing_models.go b/model/missing_models.go index 57269f5f..18191ba6 100644 --- a/model/missing_models.go +++ b/model/missing_models.go @@ -2,29 +2,29 @@ package model // GetMissingModels returns model names that are referenced in the system func GetMissingModels() ([]string, error) { - // 1. 获取所有已启用模型(去重) - models := GetEnabledModels() - if len(models) == 0 { - return []string{}, nil - } + // 1. 获取所有已启用模型(去重) + models := GetEnabledModels() + if len(models) == 0 { + return []string{}, nil + } - // 2. 查询已有的元数据模型名 - var existing []string - if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil { - return nil, err - } + // 2. 查询已有的元数据模型名 + var existing []string + if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil { + return nil, err + } - existingSet := make(map[string]struct{}, len(existing)) - for _, e := range existing { - existingSet[e] = struct{}{} - } + existingSet := make(map[string]struct{}, len(existing)) + for _, e := range existing { + existingSet[e] = struct{}{} + } - // 3. 收集缺失模型 - var missing []string - for _, name := range models { - if _, ok := existingSet[name]; !ok { - missing = append(missing, name) - } - } - return missing, nil + // 3. 收集缺失模型 + var missing []string + for _, name := range models { + if _, ok := existingSet[name]; !ok { + missing = append(missing, name) + } + } + return missing, nil } diff --git a/model/model_extra.go b/model/model_extra.go index 6ade6ff0..495b5bb9 100644 --- a/model/model_extra.go +++ b/model/model_extra.go @@ -3,32 +3,32 @@ package model // GetModelEnableGroups 返回指定模型名称可用的用户分组列表。 // 使用在 updatePricing() 中维护的缓存映射,O(1) 读取,适合高并发场景。 func GetModelEnableGroups(modelName string) []string { - // 确保缓存最新 - GetPricing() + // 确保缓存最新 + GetPricing() - if modelName == "" { - return make([]string, 0) - } + if modelName == "" { + return make([]string, 0) + } - modelEnableGroupsLock.RLock() - groups, ok := modelEnableGroups[modelName] - modelEnableGroupsLock.RUnlock() - if !ok { - return make([]string, 0) - } - return groups + modelEnableGroupsLock.RLock() + groups, ok := modelEnableGroups[modelName] + modelEnableGroupsLock.RUnlock() + if !ok { + return make([]string, 0) + } + return groups } // GetModelQuotaType 返回指定模型的计费类型(quota_type)。 // 同样使用缓存映射,避免每次遍历定价切片。 func GetModelQuotaType(modelName string) int { - GetPricing() + GetPricing() - modelEnableGroupsLock.RLock() - quota, ok := modelQuotaTypeMap[modelName] - modelEnableGroupsLock.RUnlock() - if !ok { - return 0 - } - return quota + modelEnableGroupsLock.RLock() + quota, ok := modelQuotaTypeMap[modelName] + modelEnableGroupsLock.RUnlock() + if !ok { + return 0 + } + return quota } diff --git a/model/model_meta.go b/model/model_meta.go index 371e3137..b69608f1 100644 --- a/model/model_meta.go +++ b/model/model_meta.go @@ -1,11 +1,11 @@ package model import ( - "one-api/common" - "strconv" - "strings" + "one-api/common" + "strconv" + "strings" - "gorm.io/gorm" + "gorm.io/gorm" ) // Model 用于存储模型的元数据,例如描述、标签等 @@ -23,186 +23,186 @@ import ( // 模型名称匹配规则 const ( - NameRuleExact = iota // 0 精确匹配 - NameRulePrefix // 1 前缀匹配 - NameRuleContains // 2 包含匹配 - NameRuleSuffix // 3 后缀匹配 + NameRuleExact = iota // 0 精确匹配 + NameRulePrefix // 1 前缀匹配 + NameRuleContains // 2 包含匹配 + NameRuleSuffix // 3 后缀匹配 ) type BoundChannel struct { - Name string `json:"name"` - Type int `json:"type"` + Name string `json:"name"` + Type int `json:"type"` } type Model struct { - Id int `json:"id"` - ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"` - Description string `json:"description,omitempty" gorm:"type:text"` - Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` - Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"` - VendorID int `json:"vendor_id,omitempty" gorm:"index"` - Endpoints string `json:"endpoints,omitempty" gorm:"type:text"` - Status int `json:"status" gorm:"default:1"` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - UpdatedTime int64 `json:"updated_time" gorm:"bigint"` - DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"` + Id int `json:"id"` + ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"` + Description string `json:"description,omitempty" gorm:"type:text"` + Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` + Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"` + VendorID int `json:"vendor_id,omitempty" gorm:"index"` + Endpoints string `json:"endpoints,omitempty" gorm:"type:text"` + Status int `json:"status" gorm:"default:1"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"` - BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"` - EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"` - QuotaType int `json:"quota_type" gorm:"-"` - NameRule int `json:"name_rule" gorm:"default:0"` + BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"` + EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"` + QuotaType int `json:"quota_type" gorm:"-"` + NameRule int `json:"name_rule" gorm:"default:0"` } // Insert 创建新的模型元数据记录 func (mi *Model) Insert() error { - now := common.GetTimestamp() - mi.CreatedTime = now - mi.UpdatedTime = now - return DB.Create(mi).Error + now := common.GetTimestamp() + mi.CreatedTime = now + mi.UpdatedTime = now + return DB.Create(mi).Error } // IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID) func IsModelNameDuplicated(id int, name string) (bool, error) { - if name == "" { - return false, nil - } - var cnt int64 - err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error - return cnt > 0, err + if name == "" { + return false, nil + } + var cnt int64 + err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error + return cnt > 0, err } // Update 更新现有模型记录 func (mi *Model) Update() error { - mi.UpdatedTime = common.GetTimestamp() - // 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新 - return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}). - Model(&Model{}). - Where("id = ?", mi.Id). - Omit("created_time"). - Select("*"). - Updates(mi).Error + mi.UpdatedTime = common.GetTimestamp() + // 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新 + return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}). + Model(&Model{}). + Where("id = ?", mi.Id). + Omit("created_time"). + Select("*"). + Updates(mi).Error } // Delete 软删除模型记录 func (mi *Model) Delete() error { - return DB.Delete(mi).Error + return DB.Delete(mi).Error } // GetModelByName 根据模型名称查询元数据 func GetModelByName(name string) (*Model, error) { - var mi Model - err := DB.Where("model_name = ?", name).First(&mi).Error - if err != nil { - return nil, err - } - return &mi, nil + var mi Model + err := DB.Where("model_name = ?", name).First(&mi).Error + if err != nil { + return nil, err + } + return &mi, nil } // GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响) func GetVendorModelCounts() (map[int64]int64, error) { - var stats []struct { - VendorID int64 - Count int64 - } - if err := DB.Model(&Model{}). - Select("vendor_id as vendor_id, count(*) as count"). - Group("vendor_id"). - Scan(&stats).Error; err != nil { - return nil, err - } - m := make(map[int64]int64, len(stats)) - for _, s := range stats { - m[s.VendorID] = s.Count - } - return m, nil + var stats []struct { + VendorID int64 + Count int64 + } + if err := DB.Model(&Model{}). + Select("vendor_id as vendor_id, count(*) as count"). + Group("vendor_id"). + Scan(&stats).Error; err != nil { + return nil, err + } + m := make(map[int64]int64, len(stats)) + for _, s := range stats { + m[s.VendorID] = s.Count + } + return m, nil } // GetAllModels 分页获取所有模型元数据 func GetAllModels(offset int, limit int) ([]*Model, error) { - var models []*Model - err := DB.Offset(offset).Limit(limit).Find(&models).Error - return models, err + var models []*Model + err := DB.Offset(offset).Limit(limit).Find(&models).Error + return models, err } // GetBoundChannels 查询支持该模型的渠道(名称+类型) func GetBoundChannels(modelName string) ([]BoundChannel, error) { - var channels []BoundChannel - err := DB.Table("channels"). - Select("channels.name, channels.type"). - Joins("join abilities on abilities.channel_id = channels.id"). - Where("abilities.model = ? AND abilities.enabled = ?", modelName, true). - Group("channels.id"). - Scan(&channels).Error - return channels, err + var channels []BoundChannel + err := DB.Table("channels"). + Select("channels.name, channels.type"). + Joins("join abilities on abilities.channel_id = channels.id"). + Where("abilities.model = ? AND abilities.enabled = ?", modelName, true). + Group("channels.id"). + Scan(&channels).Error + return channels, err } // FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含 func FindModelByNameWithRule(name string) (*Model, error) { - // 1. 精确匹配 - if m, err := GetModelByName(name); err == nil { - return m, nil - } - // 2. 规则匹配 - var models []*Model - if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil { - return nil, err - } - var prefixMatch, suffixMatch, containsMatch *Model - for _, m := range models { - switch m.NameRule { - case NameRulePrefix: - if strings.HasPrefix(name, m.ModelName) { - if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) { - prefixMatch = m - } - } - case NameRuleSuffix: - if strings.HasSuffix(name, m.ModelName) { - if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) { - suffixMatch = m - } - } - case NameRuleContains: - if strings.Contains(name, m.ModelName) { - if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) { - containsMatch = m - } - } - } - } - if prefixMatch != nil { - return prefixMatch, nil - } - if suffixMatch != nil { - return suffixMatch, nil - } - if containsMatch != nil { - return containsMatch, nil - } - return nil, gorm.ErrRecordNotFound + // 1. 精确匹配 + if m, err := GetModelByName(name); err == nil { + return m, nil + } + // 2. 规则匹配 + var models []*Model + if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil { + return nil, err + } + var prefixMatch, suffixMatch, containsMatch *Model + for _, m := range models { + switch m.NameRule { + case NameRulePrefix: + if strings.HasPrefix(name, m.ModelName) { + if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) { + prefixMatch = m + } + } + case NameRuleSuffix: + if strings.HasSuffix(name, m.ModelName) { + if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) { + suffixMatch = m + } + } + case NameRuleContains: + if strings.Contains(name, m.ModelName) { + if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) { + containsMatch = m + } + } + } + } + if prefixMatch != nil { + return prefixMatch, nil + } + if suffixMatch != nil { + return suffixMatch, nil + } + if containsMatch != nil { + return containsMatch, nil + } + return nil, gorm.ErrRecordNotFound } // SearchModels 根据关键词和供应商搜索模型,支持分页 func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) { - var models []*Model - db := DB.Model(&Model{}) - if keyword != "" { - like := "%" + keyword + "%" - db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like) - } - if vendor != "" { - // 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配 - if vid, err := strconv.Atoi(vendor); err == nil { - db = db.Where("models.vendor_id = ?", vid) - } else { - db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%") - } - } - var total int64 - err := db.Count(&total).Error - if err != nil { - return nil, 0, err - } - err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error - return models, total, err + var models []*Model + db := DB.Model(&Model{}) + if keyword != "" { + like := "%" + keyword + "%" + db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like) + } + if vendor != "" { + // 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配 + if vid, err := strconv.Atoi(vendor); err == nil { + db = db.Where("models.vendor_id = ?", vid) + } else { + db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%") + } + } + var total int64 + err := db.Count(&total).Error + if err != nil { + return nil, 0, err + } + err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error + return models, total, err } diff --git a/model/prefill_group.go b/model/prefill_group.go index d9e92faa..a21b76fe 100644 --- a/model/prefill_group.go +++ b/model/prefill_group.go @@ -1,11 +1,11 @@ package model import ( - "encoding/json" - "database/sql/driver" - "one-api/common" + "database/sql/driver" + "encoding/json" + "one-api/common" - "gorm.io/gorm" + "gorm.io/gorm" ) // PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。 @@ -20,107 +20,107 @@ type JSONValue json.RawMessage // Value 实现 driver.Valuer 接口,用于数据库写入 func (j JSONValue) Value() (driver.Value, error) { - if j == nil { - return nil, nil - } - return []byte(j), nil + if j == nil { + return nil, nil + } + return []byte(j), nil } // Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型 func (j *JSONValue) Scan(value interface{}) error { - switch v := value.(type) { - case nil: - *j = nil - return nil - case []byte: - // 拷贝底层字节,避免保留底层缓冲区 - b := make([]byte, len(v)) - copy(b, v) - *j = JSONValue(b) - return nil - case string: - *j = JSONValue([]byte(v)) - return nil - default: - // 其他类型尝试序列化为 JSON - b, err := json.Marshal(v) - if err != nil { - return err - } - *j = JSONValue(b) - return nil - } + switch v := value.(type) { + case nil: + *j = nil + return nil + case []byte: + // 拷贝底层字节,避免保留底层缓冲区 + b := make([]byte, len(v)) + copy(b, v) + *j = JSONValue(b) + return nil + case string: + *j = JSONValue([]byte(v)) + return nil + default: + // 其他类型尝试序列化为 JSON + b, err := json.Marshal(v) + if err != nil { + return err + } + *j = JSONValue(b) + return nil + } } // MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致 func (j JSONValue) MarshalJSON() ([]byte, error) { - if j == nil { - return []byte("null"), nil - } - return j, nil + if j == nil { + return []byte("null"), nil + } + return j, nil } // UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致 func (j *JSONValue) UnmarshalJSON(data []byte) error { - if data == nil { - *j = nil - return nil - } - b := make([]byte, len(data)) - copy(b, data) - *j = JSONValue(b) - return nil + if data == nil { + *j = nil + return nil + } + b := make([]byte, len(data)) + copy(b, data) + *j = JSONValue(b) + return nil } type PrefillGroup struct { - Id int `json:"id"` - Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"` - Type string `json:"type" gorm:"size:32;index;not null"` - Items JSONValue `json:"items" gorm:"type:json"` - Description string `json:"description,omitempty" gorm:"type:varchar(255)"` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - UpdatedTime int64 `json:"updated_time" gorm:"bigint"` - DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` + Id int `json:"id"` + Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"` + Type string `json:"type" gorm:"size:32;index;not null"` + Items JSONValue `json:"items" gorm:"type:json"` + Description string `json:"description,omitempty" gorm:"type:varchar(255)"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` } // Insert 新建组 func (g *PrefillGroup) Insert() error { - now := common.GetTimestamp() - g.CreatedTime = now - g.UpdatedTime = now - return DB.Create(g).Error + now := common.GetTimestamp() + g.CreatedTime = now + g.UpdatedTime = now + return DB.Create(g).Error } // IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID) func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) { - if name == "" { - return false, nil - } - var cnt int64 - err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error - return cnt > 0, err + if name == "" { + return false, nil + } + var cnt int64 + err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error + return cnt > 0, err } // Update 更新组 func (g *PrefillGroup) Update() error { - g.UpdatedTime = common.GetTimestamp() - return DB.Save(g).Error + g.UpdatedTime = common.GetTimestamp() + return DB.Save(g).Error } // DeleteByID 根据 ID 删除组 func DeletePrefillGroupByID(id int) error { - return DB.Delete(&PrefillGroup{}, id).Error + return DB.Delete(&PrefillGroup{}, id).Error } // GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部) func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) { - var groups []*PrefillGroup - query := DB.Model(&PrefillGroup{}) - if groupType != "" { - query = query.Where("type = ?", groupType) - } - if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil { - return nil, err - } - return groups, nil + var groups []*PrefillGroup + query := DB.Model(&PrefillGroup{}) + if groupType != "" { + query = query.Where("type = ?", groupType) + } + if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil { + return nil, err + } + return groups, nil } diff --git a/model/pricing.go b/model/pricing.go index c5fbff36..0936d298 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -1,31 +1,31 @@ package model import ( - "encoding/json" - "fmt" - "strings" + "encoding/json" + "fmt" + "strings" - "one-api/common" - "one-api/constant" - "one-api/setting/ratio_setting" - "one-api/types" - "sync" - "time" + "one-api/common" + "one-api/constant" + "one-api/setting/ratio_setting" + "one-api/types" + "sync" + "time" ) type Pricing struct { - ModelName string `json:"model_name"` - Description string `json:"description,omitempty"` - Icon string `json:"icon,omitempty"` - Tags string `json:"tags,omitempty"` - VendorID int `json:"vendor_id,omitempty"` - QuotaType int `json:"quota_type"` - ModelRatio float64 `json:"model_ratio"` - ModelPrice float64 `json:"model_price"` - OwnerBy string `json:"owner_by"` - CompletionRatio float64 `json:"completion_ratio"` - EnableGroup []string `json:"enable_groups"` - SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` + ModelName string `json:"model_name"` + Description string `json:"description,omitempty"` + Icon string `json:"icon,omitempty"` + Tags string `json:"tags,omitempty"` + VendorID int `json:"vendor_id,omitempty"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + OwnerBy string `json:"owner_by"` + CompletionRatio float64 `json:"completion_ratio"` + EnableGroup []string `json:"enable_groups"` + SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` } type PricingVendor struct { @@ -36,11 +36,11 @@ type PricingVendor struct { } var ( - pricingMap []Pricing - vendorsList []PricingVendor - supportedEndpointMap map[string]common.EndpointInfo - lastGetPricingTime time.Time - updatePricingLock sync.Mutex + pricingMap []Pricing + vendorsList []PricingVendor + supportedEndpointMap map[string]common.EndpointInfo + lastGetPricingTime time.Time + updatePricingLock sync.Mutex // 缓存映射:模型名 -> 启用分组 / 计费类型 modelEnableGroups = make(map[string][]string) @@ -122,19 +122,19 @@ func updatePricing() { for _, m := range prefixList { for _, pricingModel := range enableAbilities { if strings.HasPrefix(pricingModel.Model, m.ModelName) { - if _, exists := metaMap[pricingModel.Model]; !exists { - metaMap[pricingModel.Model] = m - } - } + if _, exists := metaMap[pricingModel.Model]; !exists { + metaMap[pricingModel.Model] = m + } + } } } for _, m := range suffixList { for _, pricingModel := range enableAbilities { if strings.HasSuffix(pricingModel.Model, m.ModelName) { - if _, exists := metaMap[pricingModel.Model]; !exists { - metaMap[pricingModel.Model] = m - } - } + if _, exists := metaMap[pricingModel.Model]; !exists { + metaMap[pricingModel.Model] = m + } + } } } for _, m := range containsList { @@ -180,34 +180,34 @@ func updatePricing() { //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点 modelSupportEndpointsStr := make(map[string][]string) - // 先根据已有能力填充原生端点 - for _, ability := range enableAbilities { - endpoints := modelSupportEndpointsStr[ability.Model] - channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) - for _, channelType := range channelTypes { - if !common.StringsContains(endpoints, string(channelType)) { - endpoints = append(endpoints, string(channelType)) - } - } - modelSupportEndpointsStr[ability.Model] = endpoints - } + // 先根据已有能力填充原生端点 + for _, ability := range enableAbilities { + endpoints := modelSupportEndpointsStr[ability.Model] + channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) + for _, channelType := range channelTypes { + if !common.StringsContains(endpoints, string(channelType)) { + endpoints = append(endpoints, string(channelType)) + } + } + modelSupportEndpointsStr[ability.Model] = endpoints + } - // 再补充模型自定义端点 - for modelName, meta := range metaMap { - if strings.TrimSpace(meta.Endpoints) == "" { - continue - } - var raw map[string]interface{} - if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { - endpoints := modelSupportEndpointsStr[modelName] - for k := range raw { - if !common.StringsContains(endpoints, k) { - endpoints = append(endpoints, k) - } - } - modelSupportEndpointsStr[modelName] = endpoints - } - } + // 再补充模型自定义端点 + for modelName, meta := range metaMap { + if strings.TrimSpace(meta.Endpoints) == "" { + continue + } + var raw map[string]interface{} + if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { + endpoints := modelSupportEndpointsStr[modelName] + for k := range raw { + if !common.StringsContains(endpoints, k) { + endpoints = append(endpoints, k) + } + } + modelSupportEndpointsStr[modelName] = endpoints + } + } modelSupportEndpointTypes = make(map[string][]constant.EndpointType) for model, endpoints := range modelSupportEndpointsStr { @@ -217,93 +217,93 @@ func updatePricing() { supportedEndpoints = append(supportedEndpoints, endpointType) } modelSupportEndpointTypes[model] = supportedEndpoints - } + } - // 构建全局 supportedEndpointMap(默认 + 自定义覆盖) - supportedEndpointMap = make(map[string]common.EndpointInfo) - // 1. 默认端点 - for _, endpoints := range modelSupportEndpointTypes { - for _, et := range endpoints { - if info, ok := common.GetDefaultEndpointInfo(et); ok { - if _, exists := supportedEndpointMap[string(et)]; !exists { - supportedEndpointMap[string(et)] = info - } - } - } - } - // 2. 自定义端点(models 表)覆盖默认 - for _, meta := range metaMap { - if strings.TrimSpace(meta.Endpoints) == "" { - continue - } - var raw map[string]interface{} - if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { - for k, v := range raw { - switch val := v.(type) { - case string: - supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"} - case map[string]interface{}: - ep := common.EndpointInfo{Method: "POST"} - if p, ok := val["path"].(string); ok { - ep.Path = p - } - if m, ok := val["method"].(string); ok { - ep.Method = strings.ToUpper(m) - } - supportedEndpointMap[k] = ep - default: - // ignore unsupported types - } - } - } - } + // 构建全局 supportedEndpointMap(默认 + 自定义覆盖) + supportedEndpointMap = make(map[string]common.EndpointInfo) + // 1. 默认端点 + for _, endpoints := range modelSupportEndpointTypes { + for _, et := range endpoints { + if info, ok := common.GetDefaultEndpointInfo(et); ok { + if _, exists := supportedEndpointMap[string(et)]; !exists { + supportedEndpointMap[string(et)] = info + } + } + } + } + // 2. 自定义端点(models 表)覆盖默认 + for _, meta := range metaMap { + if strings.TrimSpace(meta.Endpoints) == "" { + continue + } + var raw map[string]interface{} + if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { + for k, v := range raw { + switch val := v.(type) { + case string: + supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"} + case map[string]interface{}: + ep := common.EndpointInfo{Method: "POST"} + if p, ok := val["path"].(string); ok { + ep.Path = p + } + if m, ok := val["method"].(string); ok { + ep.Method = strings.ToUpper(m) + } + supportedEndpointMap[k] = ep + default: + // ignore unsupported types + } + } + } + } - pricingMap = make([]Pricing, 0) - for model, groups := range modelGroupsMap { - pricing := Pricing{ - ModelName: model, - EnableGroup: groups.Items(), - SupportedEndpointTypes: modelSupportEndpointTypes[model], - } + pricingMap = make([]Pricing, 0) + for model, groups := range modelGroupsMap { + pricing := Pricing{ + ModelName: model, + EnableGroup: groups.Items(), + SupportedEndpointTypes: modelSupportEndpointTypes[model], + } - // 补充模型元数据(描述、标签、供应商、状态) - if meta, ok := metaMap[model]; ok { - // 若模型被禁用(status!=1),则直接跳过,不返回给前端 - if meta.Status != 1 { - continue - } - pricing.Description = meta.Description - pricing.Icon = meta.Icon - pricing.Tags = meta.Tags - pricing.VendorID = meta.VendorID - } - modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) - if findPrice { - pricing.ModelPrice = modelPrice - pricing.QuotaType = 1 - } else { - modelRatio, _, _ := ratio_setting.GetModelRatio(model) - pricing.ModelRatio = modelRatio - pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model) - pricing.QuotaType = 0 - } - pricingMap = append(pricingMap, pricing) - } + // 补充模型元数据(描述、标签、供应商、状态) + if meta, ok := metaMap[model]; ok { + // 若模型被禁用(status!=1),则直接跳过,不返回给前端 + if meta.Status != 1 { + continue + } + pricing.Description = meta.Description + pricing.Icon = meta.Icon + pricing.Tags = meta.Tags + pricing.VendorID = meta.VendorID + } + modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) + if findPrice { + pricing.ModelPrice = modelPrice + pricing.QuotaType = 1 + } else { + modelRatio, _, _ := ratio_setting.GetModelRatio(model) + pricing.ModelRatio = modelRatio + pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model) + pricing.QuotaType = 0 + } + pricingMap = append(pricingMap, pricing) + } - // 刷新缓存映射,供高并发快速查询 - modelEnableGroupsLock.Lock() - modelEnableGroups = make(map[string][]string) - modelQuotaTypeMap = make(map[string]int) - for _, p := range pricingMap { - modelEnableGroups[p.ModelName] = p.EnableGroup - modelQuotaTypeMap[p.ModelName] = p.QuotaType - } - modelEnableGroupsLock.Unlock() + // 刷新缓存映射,供高并发快速查询 + modelEnableGroupsLock.Lock() + modelEnableGroups = make(map[string][]string) + modelQuotaTypeMap = make(map[string]int) + for _, p := range pricingMap { + modelEnableGroups[p.ModelName] = p.EnableGroup + modelQuotaTypeMap[p.ModelName] = p.QuotaType + } + modelEnableGroupsLock.Unlock() - lastGetPricingTime = time.Now() + lastGetPricingTime = time.Now() } // GetSupportedEndpointMap 返回全局端点到路径的映射 func GetSupportedEndpointMap() map[string]common.EndpointInfo { - return supportedEndpointMap + return supportedEndpointMap } diff --git a/model/pricing_refresh.go b/model/pricing_refresh.go index de72a8bb..cd0d7559 100644 --- a/model/pricing_refresh.go +++ b/model/pricing_refresh.go @@ -4,11 +4,11 @@ package model // 该方法用于需要最新数据的内部管理 API, // 因此会绕过默认的 1 分钟延迟刷新。 func RefreshPricing() { - updatePricingLock.Lock() - defer updatePricingLock.Unlock() + updatePricingLock.Lock() + defer updatePricingLock.Unlock() - modelSupportEndpointsLock.Lock() - defer modelSupportEndpointsLock.Unlock() + modelSupportEndpointsLock.Lock() + defer modelSupportEndpointsLock.Unlock() - updatePricing() + updatePricing() } diff --git a/model/vendor_meta.go b/model/vendor_meta.go index b96b1d5c..88439f24 100644 --- a/model/vendor_meta.go +++ b/model/vendor_meta.go @@ -1,9 +1,9 @@ package model import ( - "one-api/common" + "one-api/common" - "gorm.io/gorm" + "gorm.io/gorm" ) // Vendor 用于存储供应商信息,供模型引用 @@ -13,76 +13,76 @@ import ( // 本表同样遵循 3NF 设计范式 type Vendor struct { - Id int `json:"id"` - Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"` - Description string `json:"description,omitempty" gorm:"type:text"` - Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` - Status int `json:"status" gorm:"default:1"` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - UpdatedTime int64 `json:"updated_time" gorm:"bigint"` - DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"` + Id int `json:"id"` + Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"` + Description string `json:"description,omitempty" gorm:"type:text"` + Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` + Status int `json:"status" gorm:"default:1"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"` } // Insert 创建新的供应商记录 func (v *Vendor) Insert() error { - now := common.GetTimestamp() - v.CreatedTime = now - v.UpdatedTime = now - return DB.Create(v).Error + now := common.GetTimestamp() + v.CreatedTime = now + v.UpdatedTime = now + return DB.Create(v).Error } // IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID) func IsVendorNameDuplicated(id int, name string) (bool, error) { - if name == "" { - return false, nil - } - var cnt int64 - err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error - return cnt > 0, err + if name == "" { + return false, nil + } + var cnt int64 + err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error + return cnt > 0, err } // Update 更新供应商记录 func (v *Vendor) Update() error { - v.UpdatedTime = common.GetTimestamp() - return DB.Save(v).Error + v.UpdatedTime = common.GetTimestamp() + return DB.Save(v).Error } // Delete 软删除供应商 func (v *Vendor) Delete() error { - return DB.Delete(v).Error + return DB.Delete(v).Error } // GetVendorByID 根据 ID 获取供应商 func GetVendorByID(id int) (*Vendor, error) { - var v Vendor - err := DB.First(&v, id).Error - if err != nil { - return nil, err - } - return &v, nil + var v Vendor + err := DB.First(&v, id).Error + if err != nil { + return nil, err + } + return &v, nil } // GetAllVendors 获取全部供应商(分页) func GetAllVendors(offset int, limit int) ([]*Vendor, error) { - var vendors []*Vendor - err := DB.Offset(offset).Limit(limit).Find(&vendors).Error - return vendors, err + var vendors []*Vendor + err := DB.Offset(offset).Limit(limit).Find(&vendors).Error + return vendors, err } // SearchVendors 按关键字搜索供应商 func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) { - db := DB.Model(&Vendor{}) - if keyword != "" { - like := "%" + keyword + "%" - db = db.Where("name LIKE ? OR description LIKE ?", like, like) - } - var total int64 - if err := db.Count(&total).Error; err != nil { - return nil, 0, err - } - var vendors []*Vendor - if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil { - return nil, 0, err - } - return vendors, total, nil -} \ No newline at end of file + db := DB.Model(&Vendor{}) + if keyword != "" { + like := "%" + keyword + "%" + db = db.Where("name LIKE ? OR description LIKE ?", like, like) + } + var total int64 + if err := db.Count(&total).Error; err != nil { + return nil, 0, err + } + var vendors []*Vendor + if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil { + return nil, 0, err + } + return vendors, total, nil +}