🖼️ chore: format code file
This commit is contained in:
@@ -11,22 +11,22 @@ import "one-api/constant"
|
|||||||
// 例如:{"path":"/v1/chat/completions","method":"POST"}
|
// 例如:{"path":"/v1/chat/completions","method":"POST"}
|
||||||
|
|
||||||
type EndpointInfo struct {
|
type EndpointInfo struct {
|
||||||
Path string `json:"path"`
|
Path string `json:"path"`
|
||||||
Method string `json:"method"`
|
Method string `json:"method"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
|
// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
|
||||||
var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
|
var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
|
||||||
constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
|
constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
|
||||||
constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
|
constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
|
||||||
constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
|
constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
|
||||||
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
|
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
|
||||||
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
|
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
|
||||||
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
|
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
|
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
|
||||||
func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
|
func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
|
||||||
info, ok := defaultEndpointInfoMap[et]
|
info, ok := defaultEndpointInfoMap[et]
|
||||||
return info, ok
|
return info, ok
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +1,27 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetMissingModels returns the list of model names that are referenced by channels
|
// GetMissingModels returns the list of model names that are referenced by channels
|
||||||
// but do not have corresponding records in the models meta table.
|
// but do not have corresponding records in the models meta table.
|
||||||
// This helps administrators quickly discover models that need configuration.
|
// This helps administrators quickly discover models that need configuration.
|
||||||
func GetMissingModels(c *gin.Context) {
|
func GetMissingModels(c *gin.Context) {
|
||||||
missing, err := model.GetMissingModels()
|
missing, err := model.GetMissingModels()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": missing,
|
"data": missing,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,178 +1,178 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetAllModelsMeta 获取模型列表(分页)
|
// GetAllModelsMeta 获取模型列表(分页)
|
||||||
func GetAllModelsMeta(c *gin.Context) {
|
func GetAllModelsMeta(c *gin.Context) {
|
||||||
|
|
||||||
pageInfo := common.GetPageQuery(c)
|
pageInfo := common.GetPageQuery(c)
|
||||||
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 填充附加字段
|
// 填充附加字段
|
||||||
for _, m := range modelsMeta {
|
for _, m := range modelsMeta {
|
||||||
fillModelExtra(m)
|
fillModelExtra(m)
|
||||||
}
|
}
|
||||||
var total int64
|
var total int64
|
||||||
model.DB.Model(&model.Model{}).Count(&total)
|
model.DB.Model(&model.Model{}).Count(&total)
|
||||||
|
|
||||||
// 统计供应商计数(全部数据,不受分页影响)
|
// 统计供应商计数(全部数据,不受分页影响)
|
||||||
vendorCounts, _ := model.GetVendorModelCounts()
|
vendorCounts, _ := model.GetVendorModelCounts()
|
||||||
|
|
||||||
pageInfo.SetTotal(int(total))
|
pageInfo.SetTotal(int(total))
|
||||||
pageInfo.SetItems(modelsMeta)
|
pageInfo.SetItems(modelsMeta)
|
||||||
common.ApiSuccess(c, gin.H{
|
common.ApiSuccess(c, gin.H{
|
||||||
"items": modelsMeta,
|
"items": modelsMeta,
|
||||||
"total": total,
|
"total": total,
|
||||||
"page": pageInfo.GetPage(),
|
"page": pageInfo.GetPage(),
|
||||||
"page_size": pageInfo.GetPageSize(),
|
"page_size": pageInfo.GetPageSize(),
|
||||||
"vendor_counts": vendorCounts,
|
"vendor_counts": vendorCounts,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchModelsMeta 搜索模型列表
|
// SearchModelsMeta 搜索模型列表
|
||||||
func SearchModelsMeta(c *gin.Context) {
|
func SearchModelsMeta(c *gin.Context) {
|
||||||
|
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
vendor := c.Query("vendor")
|
vendor := c.Query("vendor")
|
||||||
pageInfo := common.GetPageQuery(c)
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
|
||||||
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, m := range modelsMeta {
|
for _, m := range modelsMeta {
|
||||||
fillModelExtra(m)
|
fillModelExtra(m)
|
||||||
}
|
}
|
||||||
pageInfo.SetTotal(int(total))
|
pageInfo.SetTotal(int(total))
|
||||||
pageInfo.SetItems(modelsMeta)
|
pageInfo.SetItems(modelsMeta)
|
||||||
common.ApiSuccess(c, pageInfo)
|
common.ApiSuccess(c, pageInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelMeta 根据 ID 获取单条模型信息
|
// GetModelMeta 根据 ID 获取单条模型信息
|
||||||
func GetModelMeta(c *gin.Context) {
|
func GetModelMeta(c *gin.Context) {
|
||||||
idStr := c.Param("id")
|
idStr := c.Param("id")
|
||||||
id, err := strconv.Atoi(idStr)
|
id, err := strconv.Atoi(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var m model.Model
|
var m model.Model
|
||||||
if err := model.DB.First(&m, id).Error; err != nil {
|
if err := model.DB.First(&m, id).Error; err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fillModelExtra(&m)
|
fillModelExtra(&m)
|
||||||
common.ApiSuccess(c, &m)
|
common.ApiSuccess(c, &m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateModelMeta 新建模型
|
// CreateModelMeta 新建模型
|
||||||
func CreateModelMeta(c *gin.Context) {
|
func CreateModelMeta(c *gin.Context) {
|
||||||
var m model.Model
|
var m model.Model
|
||||||
if err := c.ShouldBindJSON(&m); err != nil {
|
if err := c.ShouldBindJSON(&m); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if m.ModelName == "" {
|
if m.ModelName == "" {
|
||||||
common.ApiErrorMsg(c, "模型名称不能为空")
|
common.ApiErrorMsg(c, "模型名称不能为空")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 名称冲突检查
|
// 名称冲突检查
|
||||||
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
|
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
} else if dup {
|
} else if dup {
|
||||||
common.ApiErrorMsg(c, "模型名称已存在")
|
common.ApiErrorMsg(c, "模型名称已存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.Insert(); err != nil {
|
if err := m.Insert(); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
model.RefreshPricing()
|
model.RefreshPricing()
|
||||||
common.ApiSuccess(c, &m)
|
common.ApiSuccess(c, &m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateModelMeta 更新模型
|
// UpdateModelMeta 更新模型
|
||||||
func UpdateModelMeta(c *gin.Context) {
|
func UpdateModelMeta(c *gin.Context) {
|
||||||
statusOnly := c.Query("status_only") == "true"
|
statusOnly := c.Query("status_only") == "true"
|
||||||
|
|
||||||
var m model.Model
|
var m model.Model
|
||||||
if err := c.ShouldBindJSON(&m); err != nil {
|
if err := c.ShouldBindJSON(&m); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if m.Id == 0 {
|
if m.Id == 0 {
|
||||||
common.ApiErrorMsg(c, "缺少模型 ID")
|
common.ApiErrorMsg(c, "缺少模型 ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if statusOnly {
|
if statusOnly {
|
||||||
// 只更新状态,防止误清空其他字段
|
// 只更新状态,防止误清空其他字段
|
||||||
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
|
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 名称冲突检查
|
// 名称冲突检查
|
||||||
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
|
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
} else if dup {
|
} else if dup {
|
||||||
common.ApiErrorMsg(c, "模型名称已存在")
|
common.ApiErrorMsg(c, "模型名称已存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.Update(); err != nil {
|
if err := m.Update(); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
model.RefreshPricing()
|
model.RefreshPricing()
|
||||||
common.ApiSuccess(c, &m)
|
common.ApiSuccess(c, &m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteModelMeta 删除模型
|
// DeleteModelMeta 删除模型
|
||||||
func DeleteModelMeta(c *gin.Context) {
|
func DeleteModelMeta(c *gin.Context) {
|
||||||
idStr := c.Param("id")
|
idStr := c.Param("id")
|
||||||
id, err := strconv.Atoi(idStr)
|
id, err := strconv.Atoi(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
|
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
model.RefreshPricing()
|
model.RefreshPricing()
|
||||||
common.ApiSuccess(c, nil)
|
common.ApiSuccess(c, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
|
// 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
|
||||||
func fillModelExtra(m *model.Model) {
|
func fillModelExtra(m *model.Model) {
|
||||||
if m.Endpoints == "" {
|
if m.Endpoints == "" {
|
||||||
eps := model.GetModelSupportEndpointTypes(m.ModelName)
|
eps := model.GetModelSupportEndpointTypes(m.ModelName)
|
||||||
if b, err := json.Marshal(eps); err == nil {
|
if b, err := json.Marshal(eps); err == nil {
|
||||||
m.Endpoints = string(b)
|
m.Endpoints = string(b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
|
if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
|
||||||
m.BoundChannels = channels
|
m.BoundChannels = channels
|
||||||
}
|
}
|
||||||
// 填充启用分组
|
// 填充启用分组
|
||||||
m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
|
m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
|
||||||
// 填充计费类型
|
// 填充计费类型
|
||||||
m.QuotaType = model.GetModelQuotaType(m.ModelName)
|
m.QuotaType = model.GetModelQuotaType(m.ModelName)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,90 +1,90 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
|
// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
|
||||||
func GetPrefillGroups(c *gin.Context) {
|
func GetPrefillGroups(c *gin.Context) {
|
||||||
groupType := c.Query("type")
|
groupType := c.Query("type")
|
||||||
groups, err := model.GetAllPrefillGroups(groupType)
|
groups, err := model.GetAllPrefillGroups(groupType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.ApiSuccess(c, groups)
|
common.ApiSuccess(c, groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePrefillGroup 创建新的预填组
|
// CreatePrefillGroup 创建新的预填组
|
||||||
func CreatePrefillGroup(c *gin.Context) {
|
func CreatePrefillGroup(c *gin.Context) {
|
||||||
var g model.PrefillGroup
|
var g model.PrefillGroup
|
||||||
if err := c.ShouldBindJSON(&g); err != nil {
|
if err := c.ShouldBindJSON(&g); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if g.Name == "" || g.Type == "" {
|
if g.Name == "" || g.Type == "" {
|
||||||
common.ApiErrorMsg(c, "组名称和类型不能为空")
|
common.ApiErrorMsg(c, "组名称和类型不能为空")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 创建前检查名称
|
// 创建前检查名称
|
||||||
if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
|
if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
} else if dup {
|
} else if dup {
|
||||||
common.ApiErrorMsg(c, "组名称已存在")
|
common.ApiErrorMsg(c, "组名称已存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.Insert(); err != nil {
|
if err := g.Insert(); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.ApiSuccess(c, &g)
|
common.ApiSuccess(c, &g)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePrefillGroup 更新预填组
|
// UpdatePrefillGroup 更新预填组
|
||||||
func UpdatePrefillGroup(c *gin.Context) {
|
func UpdatePrefillGroup(c *gin.Context) {
|
||||||
var g model.PrefillGroup
|
var g model.PrefillGroup
|
||||||
if err := c.ShouldBindJSON(&g); err != nil {
|
if err := c.ShouldBindJSON(&g); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if g.Id == 0 {
|
if g.Id == 0 {
|
||||||
common.ApiErrorMsg(c, "缺少组 ID")
|
common.ApiErrorMsg(c, "缺少组 ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 名称冲突检查
|
// 名称冲突检查
|
||||||
if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
|
if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
} else if dup {
|
} else if dup {
|
||||||
common.ApiErrorMsg(c, "组名称已存在")
|
common.ApiErrorMsg(c, "组名称已存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.Update(); err != nil {
|
if err := g.Update(); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.ApiSuccess(c, &g)
|
common.ApiSuccess(c, &g)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePrefillGroup 删除预填组
|
// DeletePrefillGroup 删除预填组
|
||||||
func DeletePrefillGroup(c *gin.Context) {
|
func DeletePrefillGroup(c *gin.Context) {
|
||||||
idStr := c.Param("id")
|
idStr := c.Param("id")
|
||||||
id, err := strconv.Atoi(idStr)
|
id, err := strconv.Atoi(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := model.DeletePrefillGroupByID(id); err != nil {
|
if err := model.DeletePrefillGroupByID(id); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.ApiSuccess(c, nil)
|
common.ApiSuccess(c, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,14 +39,14 @@ func GetPricing(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": pricing,
|
"data": pricing,
|
||||||
"vendors": model.GetVendors(),
|
"vendors": model.GetVendors(),
|
||||||
"group_ratio": groupRatio,
|
"group_ratio": groupRatio,
|
||||||
"usable_group": usableGroup,
|
"usable_group": usableGroup,
|
||||||
"supported_endpoint": model.GetSupportedEndpointMap(),
|
"supported_endpoint": model.GetSupportedEndpointMap(),
|
||||||
"auto_groups": setting.AutoGroups,
|
"auto_groups": setting.AutoGroups,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ResetModelRatio(c *gin.Context) {
|
func ResetModelRatio(c *gin.Context) {
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func Login(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否启用2FA
|
// 检查是否启用2FA
|
||||||
if model.IsTwoFAEnabled(user.Id) {
|
if model.IsTwoFAEnabled(user.Id) {
|
||||||
// 设置pending session,等待2FA验证
|
// 设置pending session,等待2FA验证
|
||||||
@@ -77,7 +77,7 @@ func Login(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "请输入两步验证码",
|
"message": "请输入两步验证码",
|
||||||
"success": true,
|
"success": true,
|
||||||
@@ -87,7 +87,7 @@ func Login(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
setupLogin(&user, c)
|
setupLogin(&user, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,124 +1,124 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetAllVendors 获取供应商列表(分页)
|
// GetAllVendors 获取供应商列表(分页)
|
||||||
func GetAllVendors(c *gin.Context) {
|
func GetAllVendors(c *gin.Context) {
|
||||||
pageInfo := common.GetPageQuery(c)
|
pageInfo := common.GetPageQuery(c)
|
||||||
vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var total int64
|
var total int64
|
||||||
model.DB.Model(&model.Vendor{}).Count(&total)
|
model.DB.Model(&model.Vendor{}).Count(&total)
|
||||||
pageInfo.SetTotal(int(total))
|
pageInfo.SetTotal(int(total))
|
||||||
pageInfo.SetItems(vendors)
|
pageInfo.SetItems(vendors)
|
||||||
common.ApiSuccess(c, pageInfo)
|
common.ApiSuccess(c, pageInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchVendors 搜索供应商
|
// SearchVendors 搜索供应商
|
||||||
func SearchVendors(c *gin.Context) {
|
func SearchVendors(c *gin.Context) {
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
pageInfo := common.GetPageQuery(c)
|
pageInfo := common.GetPageQuery(c)
|
||||||
vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pageInfo.SetTotal(int(total))
|
pageInfo.SetTotal(int(total))
|
||||||
pageInfo.SetItems(vendors)
|
pageInfo.SetItems(vendors)
|
||||||
common.ApiSuccess(c, pageInfo)
|
common.ApiSuccess(c, pageInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetVendorMeta 根据 ID 获取供应商
|
// GetVendorMeta 根据 ID 获取供应商
|
||||||
func GetVendorMeta(c *gin.Context) {
|
func GetVendorMeta(c *gin.Context) {
|
||||||
idStr := c.Param("id")
|
idStr := c.Param("id")
|
||||||
id, err := strconv.Atoi(idStr)
|
id, err := strconv.Atoi(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
v, err := model.GetVendorByID(id)
|
v, err := model.GetVendorByID(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.ApiSuccess(c, v)
|
common.ApiSuccess(c, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateVendorMeta 新建供应商
|
// CreateVendorMeta 新建供应商
|
||||||
func CreateVendorMeta(c *gin.Context) {
|
func CreateVendorMeta(c *gin.Context) {
|
||||||
var v model.Vendor
|
var v model.Vendor
|
||||||
if err := c.ShouldBindJSON(&v); err != nil {
|
if err := c.ShouldBindJSON(&v); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if v.Name == "" {
|
if v.Name == "" {
|
||||||
common.ApiErrorMsg(c, "供应商名称不能为空")
|
common.ApiErrorMsg(c, "供应商名称不能为空")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 创建前先检查名称
|
// 创建前先检查名称
|
||||||
if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
|
if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
} else if dup {
|
} else if dup {
|
||||||
common.ApiErrorMsg(c, "供应商名称已存在")
|
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := v.Insert(); err != nil {
|
if err := v.Insert(); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.ApiSuccess(c, &v)
|
common.ApiSuccess(c, &v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateVendorMeta 更新供应商
|
// UpdateVendorMeta 更新供应商
|
||||||
func UpdateVendorMeta(c *gin.Context) {
|
func UpdateVendorMeta(c *gin.Context) {
|
||||||
var v model.Vendor
|
var v model.Vendor
|
||||||
if err := c.ShouldBindJSON(&v); err != nil {
|
if err := c.ShouldBindJSON(&v); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if v.Id == 0 {
|
if v.Id == 0 {
|
||||||
common.ApiErrorMsg(c, "缺少供应商 ID")
|
common.ApiErrorMsg(c, "缺少供应商 ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 名称冲突检查
|
// 名称冲突检查
|
||||||
if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
|
if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
} else if dup {
|
} else if dup {
|
||||||
common.ApiErrorMsg(c, "供应商名称已存在")
|
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := v.Update(); err != nil {
|
if err := v.Update(); err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.ApiSuccess(c, &v)
|
common.ApiSuccess(c, &v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteVendorMeta 删除供应商
|
// DeleteVendorMeta 删除供应商
|
||||||
func DeleteVendorMeta(c *gin.Context) {
|
func DeleteVendorMeta(c *gin.Context) {
|
||||||
idStr := c.Param("id")
|
idStr := c.Param("id")
|
||||||
id, err := strconv.Atoi(idStr)
|
id, err := strconv.Atoi(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
|
if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.ApiSuccess(c, nil)
|
common.ApiSuccess(c, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,29 +2,29 @@ package model
|
|||||||
|
|
||||||
// GetMissingModels returns model names that are referenced in the system
|
// GetMissingModels returns model names that are referenced in the system
|
||||||
func GetMissingModels() ([]string, error) {
|
func GetMissingModels() ([]string, error) {
|
||||||
// 1. 获取所有已启用模型(去重)
|
// 1. 获取所有已启用模型(去重)
|
||||||
models := GetEnabledModels()
|
models := GetEnabledModels()
|
||||||
if len(models) == 0 {
|
if len(models) == 0 {
|
||||||
return []string{}, nil
|
return []string{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 查询已有的元数据模型名
|
// 2. 查询已有的元数据模型名
|
||||||
var existing []string
|
var existing []string
|
||||||
if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil {
|
if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
existingSet := make(map[string]struct{}, len(existing))
|
existingSet := make(map[string]struct{}, len(existing))
|
||||||
for _, e := range existing {
|
for _, e := range existing {
|
||||||
existingSet[e] = struct{}{}
|
existingSet[e] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 收集缺失模型
|
// 3. 收集缺失模型
|
||||||
var missing []string
|
var missing []string
|
||||||
for _, name := range models {
|
for _, name := range models {
|
||||||
if _, ok := existingSet[name]; !ok {
|
if _, ok := existingSet[name]; !ok {
|
||||||
missing = append(missing, name)
|
missing = append(missing, name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return missing, nil
|
return missing, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,32 +3,32 @@ package model
|
|||||||
// GetModelEnableGroups 返回指定模型名称可用的用户分组列表。
|
// GetModelEnableGroups 返回指定模型名称可用的用户分组列表。
|
||||||
// 使用在 updatePricing() 中维护的缓存映射,O(1) 读取,适合高并发场景。
|
// 使用在 updatePricing() 中维护的缓存映射,O(1) 读取,适合高并发场景。
|
||||||
func GetModelEnableGroups(modelName string) []string {
|
func GetModelEnableGroups(modelName string) []string {
|
||||||
// 确保缓存最新
|
// 确保缓存最新
|
||||||
GetPricing()
|
GetPricing()
|
||||||
|
|
||||||
if modelName == "" {
|
if modelName == "" {
|
||||||
return make([]string, 0)
|
return make([]string, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelEnableGroupsLock.RLock()
|
modelEnableGroupsLock.RLock()
|
||||||
groups, ok := modelEnableGroups[modelName]
|
groups, ok := modelEnableGroups[modelName]
|
||||||
modelEnableGroupsLock.RUnlock()
|
modelEnableGroupsLock.RUnlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
return make([]string, 0)
|
return make([]string, 0)
|
||||||
}
|
}
|
||||||
return groups
|
return groups
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelQuotaType 返回指定模型的计费类型(quota_type)。
|
// GetModelQuotaType 返回指定模型的计费类型(quota_type)。
|
||||||
// 同样使用缓存映射,避免每次遍历定价切片。
|
// 同样使用缓存映射,避免每次遍历定价切片。
|
||||||
func GetModelQuotaType(modelName string) int {
|
func GetModelQuotaType(modelName string) int {
|
||||||
GetPricing()
|
GetPricing()
|
||||||
|
|
||||||
modelEnableGroupsLock.RLock()
|
modelEnableGroupsLock.RLock()
|
||||||
quota, ok := modelQuotaTypeMap[modelName]
|
quota, ok := modelQuotaTypeMap[modelName]
|
||||||
modelEnableGroupsLock.RUnlock()
|
modelEnableGroupsLock.RUnlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return quota
|
return quota
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Model 用于存储模型的元数据,例如描述、标签等
|
// Model 用于存储模型的元数据,例如描述、标签等
|
||||||
@@ -23,186 +23,186 @@ import (
|
|||||||
|
|
||||||
// 模型名称匹配规则
|
// 模型名称匹配规则
|
||||||
const (
|
const (
|
||||||
NameRuleExact = iota // 0 精确匹配
|
NameRuleExact = iota // 0 精确匹配
|
||||||
NameRulePrefix // 1 前缀匹配
|
NameRulePrefix // 1 前缀匹配
|
||||||
NameRuleContains // 2 包含匹配
|
NameRuleContains // 2 包含匹配
|
||||||
NameRuleSuffix // 3 后缀匹配
|
NameRuleSuffix // 3 后缀匹配
|
||||||
)
|
)
|
||||||
|
|
||||||
type BoundChannel struct {
|
type BoundChannel struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Type int `json:"type"`
|
Type int `json:"type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"`
|
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"`
|
||||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||||
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
||||||
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
||||||
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"`
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"`
|
||||||
|
|
||||||
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
||||||
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
||||||
QuotaType int `json:"quota_type" gorm:"-"`
|
QuotaType int `json:"quota_type" gorm:"-"`
|
||||||
NameRule int `json:"name_rule" gorm:"default:0"`
|
NameRule int `json:"name_rule" gorm:"default:0"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert 创建新的模型元数据记录
|
// Insert 创建新的模型元数据记录
|
||||||
func (mi *Model) Insert() error {
|
func (mi *Model) Insert() error {
|
||||||
now := common.GetTimestamp()
|
now := common.GetTimestamp()
|
||||||
mi.CreatedTime = now
|
mi.CreatedTime = now
|
||||||
mi.UpdatedTime = now
|
mi.UpdatedTime = now
|
||||||
return DB.Create(mi).Error
|
return DB.Create(mi).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID)
|
// IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID)
|
||||||
func IsModelNameDuplicated(id int, name string) (bool, error) {
|
func IsModelNameDuplicated(id int, name string) (bool, error) {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
var cnt int64
|
var cnt int64
|
||||||
err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error
|
err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error
|
||||||
return cnt > 0, err
|
return cnt > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update 更新现有模型记录
|
// Update 更新现有模型记录
|
||||||
func (mi *Model) Update() error {
|
func (mi *Model) Update() error {
|
||||||
mi.UpdatedTime = common.GetTimestamp()
|
mi.UpdatedTime = common.GetTimestamp()
|
||||||
// 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新
|
// 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新
|
||||||
return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
|
return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
|
||||||
Model(&Model{}).
|
Model(&Model{}).
|
||||||
Where("id = ?", mi.Id).
|
Where("id = ?", mi.Id).
|
||||||
Omit("created_time").
|
Omit("created_time").
|
||||||
Select("*").
|
Select("*").
|
||||||
Updates(mi).Error
|
Updates(mi).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete 软删除模型记录
|
// Delete 软删除模型记录
|
||||||
func (mi *Model) Delete() error {
|
func (mi *Model) Delete() error {
|
||||||
return DB.Delete(mi).Error
|
return DB.Delete(mi).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelByName 根据模型名称查询元数据
|
// GetModelByName 根据模型名称查询元数据
|
||||||
func GetModelByName(name string) (*Model, error) {
|
func GetModelByName(name string) (*Model, error) {
|
||||||
var mi Model
|
var mi Model
|
||||||
err := DB.Where("model_name = ?", name).First(&mi).Error
|
err := DB.Where("model_name = ?", name).First(&mi).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &mi, nil
|
return &mi, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响)
|
// GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响)
|
||||||
func GetVendorModelCounts() (map[int64]int64, error) {
|
func GetVendorModelCounts() (map[int64]int64, error) {
|
||||||
var stats []struct {
|
var stats []struct {
|
||||||
VendorID int64
|
VendorID int64
|
||||||
Count int64
|
Count int64
|
||||||
}
|
}
|
||||||
if err := DB.Model(&Model{}).
|
if err := DB.Model(&Model{}).
|
||||||
Select("vendor_id as vendor_id, count(*) as count").
|
Select("vendor_id as vendor_id, count(*) as count").
|
||||||
Group("vendor_id").
|
Group("vendor_id").
|
||||||
Scan(&stats).Error; err != nil {
|
Scan(&stats).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := make(map[int64]int64, len(stats))
|
m := make(map[int64]int64, len(stats))
|
||||||
for _, s := range stats {
|
for _, s := range stats {
|
||||||
m[s.VendorID] = s.Count
|
m[s.VendorID] = s.Count
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllModels 分页获取所有模型元数据
|
// GetAllModels 分页获取所有模型元数据
|
||||||
func GetAllModels(offset int, limit int) ([]*Model, error) {
|
func GetAllModels(offset int, limit int) ([]*Model, error) {
|
||||||
var models []*Model
|
var models []*Model
|
||||||
err := DB.Offset(offset).Limit(limit).Find(&models).Error
|
err := DB.Offset(offset).Limit(limit).Find(&models).Error
|
||||||
return models, err
|
return models, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetBoundChannels 查询支持该模型的渠道(名称+类型)
|
// GetBoundChannels 查询支持该模型的渠道(名称+类型)
|
||||||
func GetBoundChannels(modelName string) ([]BoundChannel, error) {
|
func GetBoundChannels(modelName string) ([]BoundChannel, error) {
|
||||||
var channels []BoundChannel
|
var channels []BoundChannel
|
||||||
err := DB.Table("channels").
|
err := DB.Table("channels").
|
||||||
Select("channels.name, channels.type").
|
Select("channels.name, channels.type").
|
||||||
Joins("join abilities on abilities.channel_id = channels.id").
|
Joins("join abilities on abilities.channel_id = channels.id").
|
||||||
Where("abilities.model = ? AND abilities.enabled = ?", modelName, true).
|
Where("abilities.model = ? AND abilities.enabled = ?", modelName, true).
|
||||||
Group("channels.id").
|
Group("channels.id").
|
||||||
Scan(&channels).Error
|
Scan(&channels).Error
|
||||||
return channels, err
|
return channels, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含
|
// FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含
|
||||||
func FindModelByNameWithRule(name string) (*Model, error) {
|
func FindModelByNameWithRule(name string) (*Model, error) {
|
||||||
// 1. 精确匹配
|
// 1. 精确匹配
|
||||||
if m, err := GetModelByName(name); err == nil {
|
if m, err := GetModelByName(name); err == nil {
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
// 2. 规则匹配
|
// 2. 规则匹配
|
||||||
var models []*Model
|
var models []*Model
|
||||||
if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil {
|
if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var prefixMatch, suffixMatch, containsMatch *Model
|
var prefixMatch, suffixMatch, containsMatch *Model
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
switch m.NameRule {
|
switch m.NameRule {
|
||||||
case NameRulePrefix:
|
case NameRulePrefix:
|
||||||
if strings.HasPrefix(name, m.ModelName) {
|
if strings.HasPrefix(name, m.ModelName) {
|
||||||
if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) {
|
if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) {
|
||||||
prefixMatch = m
|
prefixMatch = m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case NameRuleSuffix:
|
case NameRuleSuffix:
|
||||||
if strings.HasSuffix(name, m.ModelName) {
|
if strings.HasSuffix(name, m.ModelName) {
|
||||||
if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) {
|
if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) {
|
||||||
suffixMatch = m
|
suffixMatch = m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case NameRuleContains:
|
case NameRuleContains:
|
||||||
if strings.Contains(name, m.ModelName) {
|
if strings.Contains(name, m.ModelName) {
|
||||||
if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) {
|
if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) {
|
||||||
containsMatch = m
|
containsMatch = m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if prefixMatch != nil {
|
if prefixMatch != nil {
|
||||||
return prefixMatch, nil
|
return prefixMatch, nil
|
||||||
}
|
}
|
||||||
if suffixMatch != nil {
|
if suffixMatch != nil {
|
||||||
return suffixMatch, nil
|
return suffixMatch, nil
|
||||||
}
|
}
|
||||||
if containsMatch != nil {
|
if containsMatch != nil {
|
||||||
return containsMatch, nil
|
return containsMatch, nil
|
||||||
}
|
}
|
||||||
return nil, gorm.ErrRecordNotFound
|
return nil, gorm.ErrRecordNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchModels 根据关键词和供应商搜索模型,支持分页
|
// SearchModels 根据关键词和供应商搜索模型,支持分页
|
||||||
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||||||
var models []*Model
|
var models []*Model
|
||||||
db := DB.Model(&Model{})
|
db := DB.Model(&Model{})
|
||||||
if keyword != "" {
|
if keyword != "" {
|
||||||
like := "%" + keyword + "%"
|
like := "%" + keyword + "%"
|
||||||
db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
|
db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
|
||||||
}
|
}
|
||||||
if vendor != "" {
|
if vendor != "" {
|
||||||
// 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配
|
// 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配
|
||||||
if vid, err := strconv.Atoi(vendor); err == nil {
|
if vid, err := strconv.Atoi(vendor); err == nil {
|
||||||
db = db.Where("models.vendor_id = ?", vid)
|
db = db.Where("models.vendor_id = ?", vid)
|
||||||
} else {
|
} else {
|
||||||
db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
|
db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var total int64
|
var total int64
|
||||||
err := db.Count(&total).Error
|
err := db.Count(&total).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error
|
err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error
|
||||||
return models, total, err
|
return models, total, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"database/sql/driver"
|
||||||
"database/sql/driver"
|
"encoding/json"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。
|
// PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。
|
||||||
@@ -20,107 +20,107 @@ type JSONValue json.RawMessage
|
|||||||
|
|
||||||
// Value 实现 driver.Valuer 接口,用于数据库写入
|
// Value 实现 driver.Valuer 接口,用于数据库写入
|
||||||
func (j JSONValue) Value() (driver.Value, error) {
|
func (j JSONValue) Value() (driver.Value, error) {
|
||||||
if j == nil {
|
if j == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return []byte(j), nil
|
return []byte(j), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型
|
// Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型
|
||||||
func (j *JSONValue) Scan(value interface{}) error {
|
func (j *JSONValue) Scan(value interface{}) error {
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
*j = nil
|
*j = nil
|
||||||
return nil
|
return nil
|
||||||
case []byte:
|
case []byte:
|
||||||
// 拷贝底层字节,避免保留底层缓冲区
|
// 拷贝底层字节,避免保留底层缓冲区
|
||||||
b := make([]byte, len(v))
|
b := make([]byte, len(v))
|
||||||
copy(b, v)
|
copy(b, v)
|
||||||
*j = JSONValue(b)
|
*j = JSONValue(b)
|
||||||
return nil
|
return nil
|
||||||
case string:
|
case string:
|
||||||
*j = JSONValue([]byte(v))
|
*j = JSONValue([]byte(v))
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
// 其他类型尝试序列化为 JSON
|
// 其他类型尝试序列化为 JSON
|
||||||
b, err := json.Marshal(v)
|
b, err := json.Marshal(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
*j = JSONValue(b)
|
*j = JSONValue(b)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致
|
// MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致
|
||||||
func (j JSONValue) MarshalJSON() ([]byte, error) {
|
func (j JSONValue) MarshalJSON() ([]byte, error) {
|
||||||
if j == nil {
|
if j == nil {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
}
|
}
|
||||||
return j, nil
|
return j, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致
|
// UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致
|
||||||
func (j *JSONValue) UnmarshalJSON(data []byte) error {
|
func (j *JSONValue) UnmarshalJSON(data []byte) error {
|
||||||
if data == nil {
|
if data == nil {
|
||||||
*j = nil
|
*j = nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
b := make([]byte, len(data))
|
b := make([]byte, len(data))
|
||||||
copy(b, data)
|
copy(b, data)
|
||||||
*j = JSONValue(b)
|
*j = JSONValue(b)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type PrefillGroup struct {
|
type PrefillGroup struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"`
|
Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"`
|
||||||
Type string `json:"type" gorm:"size:32;index;not null"`
|
Type string `json:"type" gorm:"size:32;index;not null"`
|
||||||
Items JSONValue `json:"items" gorm:"type:json"`
|
Items JSONValue `json:"items" gorm:"type:json"`
|
||||||
Description string `json:"description,omitempty" gorm:"type:varchar(255)"`
|
Description string `json:"description,omitempty" gorm:"type:varchar(255)"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert 新建组
|
// Insert 新建组
|
||||||
func (g *PrefillGroup) Insert() error {
|
func (g *PrefillGroup) Insert() error {
|
||||||
now := common.GetTimestamp()
|
now := common.GetTimestamp()
|
||||||
g.CreatedTime = now
|
g.CreatedTime = now
|
||||||
g.UpdatedTime = now
|
g.UpdatedTime = now
|
||||||
return DB.Create(g).Error
|
return DB.Create(g).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID)
|
// IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID)
|
||||||
func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) {
|
func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
var cnt int64
|
var cnt int64
|
||||||
err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
|
err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
|
||||||
return cnt > 0, err
|
return cnt > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update 更新组
|
// Update 更新组
|
||||||
func (g *PrefillGroup) Update() error {
|
func (g *PrefillGroup) Update() error {
|
||||||
g.UpdatedTime = common.GetTimestamp()
|
g.UpdatedTime = common.GetTimestamp()
|
||||||
return DB.Save(g).Error
|
return DB.Save(g).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteByID 根据 ID 删除组
|
// DeleteByID 根据 ID 删除组
|
||||||
func DeletePrefillGroupByID(id int) error {
|
func DeletePrefillGroupByID(id int) error {
|
||||||
return DB.Delete(&PrefillGroup{}, id).Error
|
return DB.Delete(&PrefillGroup{}, id).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部)
|
// GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部)
|
||||||
func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) {
|
func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) {
|
||||||
var groups []*PrefillGroup
|
var groups []*PrefillGroup
|
||||||
query := DB.Model(&PrefillGroup{})
|
query := DB.Model(&PrefillGroup{})
|
||||||
if groupType != "" {
|
if groupType != "" {
|
||||||
query = query.Where("type = ?", groupType)
|
query = query.Where("type = ?", groupType)
|
||||||
}
|
}
|
||||||
if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil {
|
if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return groups, nil
|
return groups, nil
|
||||||
}
|
}
|
||||||
|
|||||||
282
model/pricing.go
282
model/pricing.go
@@ -1,31 +1,31 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Pricing struct {
|
type Pricing struct {
|
||||||
ModelName string `json:"model_name"`
|
ModelName string `json:"model_name"`
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Icon string `json:"icon,omitempty"`
|
Icon string `json:"icon,omitempty"`
|
||||||
Tags string `json:"tags,omitempty"`
|
Tags string `json:"tags,omitempty"`
|
||||||
VendorID int `json:"vendor_id,omitempty"`
|
VendorID int `json:"vendor_id,omitempty"`
|
||||||
QuotaType int `json:"quota_type"`
|
QuotaType int `json:"quota_type"`
|
||||||
ModelRatio float64 `json:"model_ratio"`
|
ModelRatio float64 `json:"model_ratio"`
|
||||||
ModelPrice float64 `json:"model_price"`
|
ModelPrice float64 `json:"model_price"`
|
||||||
OwnerBy string `json:"owner_by"`
|
OwnerBy string `json:"owner_by"`
|
||||||
CompletionRatio float64 `json:"completion_ratio"`
|
CompletionRatio float64 `json:"completion_ratio"`
|
||||||
EnableGroup []string `json:"enable_groups"`
|
EnableGroup []string `json:"enable_groups"`
|
||||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PricingVendor struct {
|
type PricingVendor struct {
|
||||||
@@ -36,11 +36,11 @@ type PricingVendor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
pricingMap []Pricing
|
pricingMap []Pricing
|
||||||
vendorsList []PricingVendor
|
vendorsList []PricingVendor
|
||||||
supportedEndpointMap map[string]common.EndpointInfo
|
supportedEndpointMap map[string]common.EndpointInfo
|
||||||
lastGetPricingTime time.Time
|
lastGetPricingTime time.Time
|
||||||
updatePricingLock sync.Mutex
|
updatePricingLock sync.Mutex
|
||||||
|
|
||||||
// 缓存映射:模型名 -> 启用分组 / 计费类型
|
// 缓存映射:模型名 -> 启用分组 / 计费类型
|
||||||
modelEnableGroups = make(map[string][]string)
|
modelEnableGroups = make(map[string][]string)
|
||||||
@@ -122,19 +122,19 @@ func updatePricing() {
|
|||||||
for _, m := range prefixList {
|
for _, m := range prefixList {
|
||||||
for _, pricingModel := range enableAbilities {
|
for _, pricingModel := range enableAbilities {
|
||||||
if strings.HasPrefix(pricingModel.Model, m.ModelName) {
|
if strings.HasPrefix(pricingModel.Model, m.ModelName) {
|
||||||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||||||
metaMap[pricingModel.Model] = m
|
metaMap[pricingModel.Model] = m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, m := range suffixList {
|
for _, m := range suffixList {
|
||||||
for _, pricingModel := range enableAbilities {
|
for _, pricingModel := range enableAbilities {
|
||||||
if strings.HasSuffix(pricingModel.Model, m.ModelName) {
|
if strings.HasSuffix(pricingModel.Model, m.ModelName) {
|
||||||
if _, exists := metaMap[pricingModel.Model]; !exists {
|
if _, exists := metaMap[pricingModel.Model]; !exists {
|
||||||
metaMap[pricingModel.Model] = m
|
metaMap[pricingModel.Model] = m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, m := range containsList {
|
for _, m := range containsList {
|
||||||
@@ -180,34 +180,34 @@ func updatePricing() {
|
|||||||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||||||
modelSupportEndpointsStr := make(map[string][]string)
|
modelSupportEndpointsStr := make(map[string][]string)
|
||||||
|
|
||||||
// 先根据已有能力填充原生端点
|
// 先根据已有能力填充原生端点
|
||||||
for _, ability := range enableAbilities {
|
for _, ability := range enableAbilities {
|
||||||
endpoints := modelSupportEndpointsStr[ability.Model]
|
endpoints := modelSupportEndpointsStr[ability.Model]
|
||||||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||||||
for _, channelType := range channelTypes {
|
for _, channelType := range channelTypes {
|
||||||
if !common.StringsContains(endpoints, string(channelType)) {
|
if !common.StringsContains(endpoints, string(channelType)) {
|
||||||
endpoints = append(endpoints, string(channelType))
|
endpoints = append(endpoints, string(channelType))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
// 再补充模型自定义端点
|
// 再补充模型自定义端点
|
||||||
for modelName, meta := range metaMap {
|
for modelName, meta := range metaMap {
|
||||||
if strings.TrimSpace(meta.Endpoints) == "" {
|
if strings.TrimSpace(meta.Endpoints) == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var raw map[string]interface{}
|
var raw map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||||||
endpoints := modelSupportEndpointsStr[modelName]
|
endpoints := modelSupportEndpointsStr[modelName]
|
||||||
for k := range raw {
|
for k := range raw {
|
||||||
if !common.StringsContains(endpoints, k) {
|
if !common.StringsContains(endpoints, k) {
|
||||||
endpoints = append(endpoints, k)
|
endpoints = append(endpoints, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
modelSupportEndpointsStr[modelName] = endpoints
|
modelSupportEndpointsStr[modelName] = endpoints
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||||
for model, endpoints := range modelSupportEndpointsStr {
|
for model, endpoints := range modelSupportEndpointsStr {
|
||||||
@@ -217,93 +217,93 @@ func updatePricing() {
|
|||||||
supportedEndpoints = append(supportedEndpoints, endpointType)
|
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||||||
}
|
}
|
||||||
modelSupportEndpointTypes[model] = supportedEndpoints
|
modelSupportEndpointTypes[model] = supportedEndpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
|
// 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
|
||||||
supportedEndpointMap = make(map[string]common.EndpointInfo)
|
supportedEndpointMap = make(map[string]common.EndpointInfo)
|
||||||
// 1. 默认端点
|
// 1. 默认端点
|
||||||
for _, endpoints := range modelSupportEndpointTypes {
|
for _, endpoints := range modelSupportEndpointTypes {
|
||||||
for _, et := range endpoints {
|
for _, et := range endpoints {
|
||||||
if info, ok := common.GetDefaultEndpointInfo(et); ok {
|
if info, ok := common.GetDefaultEndpointInfo(et); ok {
|
||||||
if _, exists := supportedEndpointMap[string(et)]; !exists {
|
if _, exists := supportedEndpointMap[string(et)]; !exists {
|
||||||
supportedEndpointMap[string(et)] = info
|
supportedEndpointMap[string(et)] = info
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 2. 自定义端点(models 表)覆盖默认
|
// 2. 自定义端点(models 表)覆盖默认
|
||||||
for _, meta := range metaMap {
|
for _, meta := range metaMap {
|
||||||
if strings.TrimSpace(meta.Endpoints) == "" {
|
if strings.TrimSpace(meta.Endpoints) == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var raw map[string]interface{}
|
var raw map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||||||
for k, v := range raw {
|
for k, v := range raw {
|
||||||
switch val := v.(type) {
|
switch val := v.(type) {
|
||||||
case string:
|
case string:
|
||||||
supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
|
supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
ep := common.EndpointInfo{Method: "POST"}
|
ep := common.EndpointInfo{Method: "POST"}
|
||||||
if p, ok := val["path"].(string); ok {
|
if p, ok := val["path"].(string); ok {
|
||||||
ep.Path = p
|
ep.Path = p
|
||||||
}
|
}
|
||||||
if m, ok := val["method"].(string); ok {
|
if m, ok := val["method"].(string); ok {
|
||||||
ep.Method = strings.ToUpper(m)
|
ep.Method = strings.ToUpper(m)
|
||||||
}
|
}
|
||||||
supportedEndpointMap[k] = ep
|
supportedEndpointMap[k] = ep
|
||||||
default:
|
default:
|
||||||
// ignore unsupported types
|
// ignore unsupported types
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pricingMap = make([]Pricing, 0)
|
pricingMap = make([]Pricing, 0)
|
||||||
for model, groups := range modelGroupsMap {
|
for model, groups := range modelGroupsMap {
|
||||||
pricing := Pricing{
|
pricing := Pricing{
|
||||||
ModelName: model,
|
ModelName: model,
|
||||||
EnableGroup: groups.Items(),
|
EnableGroup: groups.Items(),
|
||||||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||||||
}
|
}
|
||||||
|
|
||||||
// 补充模型元数据(描述、标签、供应商、状态)
|
// 补充模型元数据(描述、标签、供应商、状态)
|
||||||
if meta, ok := metaMap[model]; ok {
|
if meta, ok := metaMap[model]; ok {
|
||||||
// 若模型被禁用(status!=1),则直接跳过,不返回给前端
|
// 若模型被禁用(status!=1),则直接跳过,不返回给前端
|
||||||
if meta.Status != 1 {
|
if meta.Status != 1 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
pricing.Description = meta.Description
|
pricing.Description = meta.Description
|
||||||
pricing.Icon = meta.Icon
|
pricing.Icon = meta.Icon
|
||||||
pricing.Tags = meta.Tags
|
pricing.Tags = meta.Tags
|
||||||
pricing.VendorID = meta.VendorID
|
pricing.VendorID = meta.VendorID
|
||||||
}
|
}
|
||||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||||
if findPrice {
|
if findPrice {
|
||||||
pricing.ModelPrice = modelPrice
|
pricing.ModelPrice = modelPrice
|
||||||
pricing.QuotaType = 1
|
pricing.QuotaType = 1
|
||||||
} else {
|
} else {
|
||||||
modelRatio, _, _ := ratio_setting.GetModelRatio(model)
|
modelRatio, _, _ := ratio_setting.GetModelRatio(model)
|
||||||
pricing.ModelRatio = modelRatio
|
pricing.ModelRatio = modelRatio
|
||||||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||||||
pricing.QuotaType = 0
|
pricing.QuotaType = 0
|
||||||
}
|
}
|
||||||
pricingMap = append(pricingMap, pricing)
|
pricingMap = append(pricingMap, pricing)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 刷新缓存映射,供高并发快速查询
|
// 刷新缓存映射,供高并发快速查询
|
||||||
modelEnableGroupsLock.Lock()
|
modelEnableGroupsLock.Lock()
|
||||||
modelEnableGroups = make(map[string][]string)
|
modelEnableGroups = make(map[string][]string)
|
||||||
modelQuotaTypeMap = make(map[string]int)
|
modelQuotaTypeMap = make(map[string]int)
|
||||||
for _, p := range pricingMap {
|
for _, p := range pricingMap {
|
||||||
modelEnableGroups[p.ModelName] = p.EnableGroup
|
modelEnableGroups[p.ModelName] = p.EnableGroup
|
||||||
modelQuotaTypeMap[p.ModelName] = p.QuotaType
|
modelQuotaTypeMap[p.ModelName] = p.QuotaType
|
||||||
}
|
}
|
||||||
modelEnableGroupsLock.Unlock()
|
modelEnableGroupsLock.Unlock()
|
||||||
|
|
||||||
lastGetPricingTime = time.Now()
|
lastGetPricingTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSupportedEndpointMap 返回全局端点到路径的映射
|
// GetSupportedEndpointMap 返回全局端点到路径的映射
|
||||||
func GetSupportedEndpointMap() map[string]common.EndpointInfo {
|
func GetSupportedEndpointMap() map[string]common.EndpointInfo {
|
||||||
return supportedEndpointMap
|
return supportedEndpointMap
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ package model
|
|||||||
// 该方法用于需要最新数据的内部管理 API,
|
// 该方法用于需要最新数据的内部管理 API,
|
||||||
// 因此会绕过默认的 1 分钟延迟刷新。
|
// 因此会绕过默认的 1 分钟延迟刷新。
|
||||||
func RefreshPricing() {
|
func RefreshPricing() {
|
||||||
updatePricingLock.Lock()
|
updatePricingLock.Lock()
|
||||||
defer updatePricingLock.Unlock()
|
defer updatePricingLock.Unlock()
|
||||||
|
|
||||||
modelSupportEndpointsLock.Lock()
|
modelSupportEndpointsLock.Lock()
|
||||||
defer modelSupportEndpointsLock.Unlock()
|
defer modelSupportEndpointsLock.Unlock()
|
||||||
|
|
||||||
updatePricing()
|
updatePricing()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Vendor 用于存储供应商信息,供模型引用
|
// Vendor 用于存储供应商信息,供模型引用
|
||||||
@@ -13,76 +13,76 @@ import (
|
|||||||
// 本表同样遵循 3NF 设计范式
|
// 本表同样遵循 3NF 设计范式
|
||||||
|
|
||||||
type Vendor struct {
|
type Vendor struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"`
|
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"`
|
||||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"`
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert 创建新的供应商记录
|
// Insert 创建新的供应商记录
|
||||||
func (v *Vendor) Insert() error {
|
func (v *Vendor) Insert() error {
|
||||||
now := common.GetTimestamp()
|
now := common.GetTimestamp()
|
||||||
v.CreatedTime = now
|
v.CreatedTime = now
|
||||||
v.UpdatedTime = now
|
v.UpdatedTime = now
|
||||||
return DB.Create(v).Error
|
return DB.Create(v).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID)
|
// IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID)
|
||||||
func IsVendorNameDuplicated(id int, name string) (bool, error) {
|
func IsVendorNameDuplicated(id int, name string) (bool, error) {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
var cnt int64
|
var cnt int64
|
||||||
err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
|
err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
|
||||||
return cnt > 0, err
|
return cnt > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update 更新供应商记录
|
// Update 更新供应商记录
|
||||||
func (v *Vendor) Update() error {
|
func (v *Vendor) Update() error {
|
||||||
v.UpdatedTime = common.GetTimestamp()
|
v.UpdatedTime = common.GetTimestamp()
|
||||||
return DB.Save(v).Error
|
return DB.Save(v).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete 软删除供应商
|
// Delete 软删除供应商
|
||||||
func (v *Vendor) Delete() error {
|
func (v *Vendor) Delete() error {
|
||||||
return DB.Delete(v).Error
|
return DB.Delete(v).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetVendorByID 根据 ID 获取供应商
|
// GetVendorByID 根据 ID 获取供应商
|
||||||
func GetVendorByID(id int) (*Vendor, error) {
|
func GetVendorByID(id int) (*Vendor, error) {
|
||||||
var v Vendor
|
var v Vendor
|
||||||
err := DB.First(&v, id).Error
|
err := DB.First(&v, id).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &v, nil
|
return &v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllVendors 获取全部供应商(分页)
|
// GetAllVendors 获取全部供应商(分页)
|
||||||
func GetAllVendors(offset int, limit int) ([]*Vendor, error) {
|
func GetAllVendors(offset int, limit int) ([]*Vendor, error) {
|
||||||
var vendors []*Vendor
|
var vendors []*Vendor
|
||||||
err := DB.Offset(offset).Limit(limit).Find(&vendors).Error
|
err := DB.Offset(offset).Limit(limit).Find(&vendors).Error
|
||||||
return vendors, err
|
return vendors, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchVendors 按关键字搜索供应商
|
// SearchVendors 按关键字搜索供应商
|
||||||
func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) {
|
func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) {
|
||||||
db := DB.Model(&Vendor{})
|
db := DB.Model(&Vendor{})
|
||||||
if keyword != "" {
|
if keyword != "" {
|
||||||
like := "%" + keyword + "%"
|
like := "%" + keyword + "%"
|
||||||
db = db.Where("name LIKE ? OR description LIKE ?", like, like)
|
db = db.Where("name LIKE ? OR description LIKE ?", like, like)
|
||||||
}
|
}
|
||||||
var total int64
|
var total int64
|
||||||
if err := db.Count(&total).Error; err != nil {
|
if err := db.Count(&total).Error; err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
var vendors []*Vendor
|
var vendors []*Vendor
|
||||||
if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil {
|
if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
return vendors, total, nil
|
return vendors, total, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user