🖼️ chore: format code file

This commit is contained in:
t0ng7u
2025-08-10 12:11:31 +08:00
parent ca1f3c6e4c
commit 1d578b73ce
14 changed files with 776 additions and 776 deletions

View File

@@ -11,22 +11,22 @@ import "one-api/constant"
// 例如:{"path":"/v1/chat/completions","method":"POST"} // 例如:{"path":"/v1/chat/completions","method":"POST"}
type EndpointInfo struct { type EndpointInfo struct {
Path string `json:"path"` Path string `json:"path"`
Method string `json:"method"` Method string `json:"method"`
} }
// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method // defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{ var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"}, constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"}, constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"}, constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"}, constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"}, constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"}, constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
} }
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在 // GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) { func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
info, ok := defaultEndpointInfoMap[et] info, ok := defaultEndpointInfoMap[et]
return info, ok return info, ok
} }

View File

@@ -1,27 +1,27 @@
package controller package controller
import ( import (
"net/http" "net/http"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// GetMissingModels returns the list of model names that are referenced by channels // GetMissingModels returns the list of model names that are referenced by channels
// but do not have corresponding records in the models meta table. // but do not have corresponding records in the models meta table.
// This helps administrators quickly discover models that need configuration. // This helps administrators quickly discover models that need configuration.
func GetMissingModels(c *gin.Context) { func GetMissingModels(c *gin.Context) {
missing, err := model.GetMissingModels() missing, err := model.GetMissingModels()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"data": missing, "data": missing,
}) })
} }

View File

@@ -1,178 +1,178 @@
package controller package controller
import ( import (
"encoding/json" "encoding/json"
"strconv" "strconv"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// GetAllModelsMeta 获取模型列表(分页) // GetAllModelsMeta 获取模型列表(分页)
func GetAllModelsMeta(c *gin.Context) { func GetAllModelsMeta(c *gin.Context) {
pageInfo := common.GetPageQuery(c) pageInfo := common.GetPageQuery(c)
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
// 填充附加字段 // 填充附加字段
for _, m := range modelsMeta { for _, m := range modelsMeta {
fillModelExtra(m) fillModelExtra(m)
} }
var total int64 var total int64
model.DB.Model(&model.Model{}).Count(&total) model.DB.Model(&model.Model{}).Count(&total)
// 统计供应商计数(全部数据,不受分页影响) // 统计供应商计数(全部数据,不受分页影响)
vendorCounts, _ := model.GetVendorModelCounts() vendorCounts, _ := model.GetVendorModelCounts()
pageInfo.SetTotal(int(total)) pageInfo.SetTotal(int(total))
pageInfo.SetItems(modelsMeta) pageInfo.SetItems(modelsMeta)
common.ApiSuccess(c, gin.H{ common.ApiSuccess(c, gin.H{
"items": modelsMeta, "items": modelsMeta,
"total": total, "total": total,
"page": pageInfo.GetPage(), "page": pageInfo.GetPage(),
"page_size": pageInfo.GetPageSize(), "page_size": pageInfo.GetPageSize(),
"vendor_counts": vendorCounts, "vendor_counts": vendorCounts,
}) })
} }
// SearchModelsMeta 搜索模型列表 // SearchModelsMeta 搜索模型列表
func SearchModelsMeta(c *gin.Context) { func SearchModelsMeta(c *gin.Context) {
keyword := c.Query("keyword") keyword := c.Query("keyword")
vendor := c.Query("vendor") vendor := c.Query("vendor")
pageInfo := common.GetPageQuery(c) pageInfo := common.GetPageQuery(c)
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
for _, m := range modelsMeta { for _, m := range modelsMeta {
fillModelExtra(m) fillModelExtra(m)
} }
pageInfo.SetTotal(int(total)) pageInfo.SetTotal(int(total))
pageInfo.SetItems(modelsMeta) pageInfo.SetItems(modelsMeta)
common.ApiSuccess(c, pageInfo) common.ApiSuccess(c, pageInfo)
} }
// GetModelMeta 根据 ID 获取单条模型信息 // GetModelMeta 根据 ID 获取单条模型信息
func GetModelMeta(c *gin.Context) { func GetModelMeta(c *gin.Context) {
idStr := c.Param("id") idStr := c.Param("id")
id, err := strconv.Atoi(idStr) id, err := strconv.Atoi(idStr)
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
var m model.Model var m model.Model
if err := model.DB.First(&m, id).Error; err != nil { if err := model.DB.First(&m, id).Error; err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
fillModelExtra(&m) fillModelExtra(&m)
common.ApiSuccess(c, &m) common.ApiSuccess(c, &m)
} }
// CreateModelMeta 新建模型 // CreateModelMeta 新建模型
func CreateModelMeta(c *gin.Context) { func CreateModelMeta(c *gin.Context) {
var m model.Model var m model.Model
if err := c.ShouldBindJSON(&m); err != nil { if err := c.ShouldBindJSON(&m); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if m.ModelName == "" { if m.ModelName == "" {
common.ApiErrorMsg(c, "模型名称不能为空") common.ApiErrorMsg(c, "模型名称不能为空")
return return
} }
// 名称冲突检查 // 名称冲突检查
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil { if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} else if dup { } else if dup {
common.ApiErrorMsg(c, "模型名称已存在") common.ApiErrorMsg(c, "模型名称已存在")
return return
} }
if err := m.Insert(); err != nil { if err := m.Insert(); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
model.RefreshPricing() model.RefreshPricing()
common.ApiSuccess(c, &m) common.ApiSuccess(c, &m)
} }
// UpdateModelMeta 更新模型 // UpdateModelMeta 更新模型
func UpdateModelMeta(c *gin.Context) { func UpdateModelMeta(c *gin.Context) {
statusOnly := c.Query("status_only") == "true" statusOnly := c.Query("status_only") == "true"
var m model.Model var m model.Model
if err := c.ShouldBindJSON(&m); err != nil { if err := c.ShouldBindJSON(&m); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if m.Id == 0 { if m.Id == 0 {
common.ApiErrorMsg(c, "缺少模型 ID") common.ApiErrorMsg(c, "缺少模型 ID")
return return
} }
if statusOnly { if statusOnly {
// 只更新状态,防止误清空其他字段 // 只更新状态,防止误清空其他字段
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil { if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
} else { } else {
// 名称冲突检查 // 名称冲突检查
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil { if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} else if dup { } else if dup {
common.ApiErrorMsg(c, "模型名称已存在") common.ApiErrorMsg(c, "模型名称已存在")
return return
} }
if err := m.Update(); err != nil { if err := m.Update(); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
} }
model.RefreshPricing() model.RefreshPricing()
common.ApiSuccess(c, &m) common.ApiSuccess(c, &m)
} }
// DeleteModelMeta 删除模型 // DeleteModelMeta 删除模型
func DeleteModelMeta(c *gin.Context) { func DeleteModelMeta(c *gin.Context) {
idStr := c.Param("id") idStr := c.Param("id")
id, err := strconv.Atoi(idStr) id, err := strconv.Atoi(idStr)
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil { if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
model.RefreshPricing() model.RefreshPricing()
common.ApiSuccess(c, nil) common.ApiSuccess(c, nil)
} }
// 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups // 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
func fillModelExtra(m *model.Model) { func fillModelExtra(m *model.Model) {
if m.Endpoints == "" { if m.Endpoints == "" {
eps := model.GetModelSupportEndpointTypes(m.ModelName) eps := model.GetModelSupportEndpointTypes(m.ModelName)
if b, err := json.Marshal(eps); err == nil { if b, err := json.Marshal(eps); err == nil {
m.Endpoints = string(b) m.Endpoints = string(b)
} }
} }
if channels, err := model.GetBoundChannels(m.ModelName); err == nil { if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
m.BoundChannels = channels m.BoundChannels = channels
} }
// 填充启用分组 // 填充启用分组
m.EnableGroups = model.GetModelEnableGroups(m.ModelName) m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
// 填充计费类型 // 填充计费类型
m.QuotaType = model.GetModelQuotaType(m.ModelName) m.QuotaType = model.GetModelQuotaType(m.ModelName)
} }

View File

@@ -1,90 +1,90 @@
package controller package controller
import ( import (
"strconv" "strconv"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤 // GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
func GetPrefillGroups(c *gin.Context) { func GetPrefillGroups(c *gin.Context) {
groupType := c.Query("type") groupType := c.Query("type")
groups, err := model.GetAllPrefillGroups(groupType) groups, err := model.GetAllPrefillGroups(groupType)
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
common.ApiSuccess(c, groups) common.ApiSuccess(c, groups)
} }
// CreatePrefillGroup 创建新的预填组 // CreatePrefillGroup 创建新的预填组
func CreatePrefillGroup(c *gin.Context) { func CreatePrefillGroup(c *gin.Context) {
var g model.PrefillGroup var g model.PrefillGroup
if err := c.ShouldBindJSON(&g); err != nil { if err := c.ShouldBindJSON(&g); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if g.Name == "" || g.Type == "" { if g.Name == "" || g.Type == "" {
common.ApiErrorMsg(c, "组名称和类型不能为空") common.ApiErrorMsg(c, "组名称和类型不能为空")
return return
} }
// 创建前检查名称 // 创建前检查名称
if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil { if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} else if dup { } else if dup {
common.ApiErrorMsg(c, "组名称已存在") common.ApiErrorMsg(c, "组名称已存在")
return return
} }
if err := g.Insert(); err != nil { if err := g.Insert(); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
common.ApiSuccess(c, &g) common.ApiSuccess(c, &g)
} }
// UpdatePrefillGroup 更新预填组 // UpdatePrefillGroup 更新预填组
func UpdatePrefillGroup(c *gin.Context) { func UpdatePrefillGroup(c *gin.Context) {
var g model.PrefillGroup var g model.PrefillGroup
if err := c.ShouldBindJSON(&g); err != nil { if err := c.ShouldBindJSON(&g); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if g.Id == 0 { if g.Id == 0 {
common.ApiErrorMsg(c, "缺少组 ID") common.ApiErrorMsg(c, "缺少组 ID")
return return
} }
// 名称冲突检查 // 名称冲突检查
if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil { if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} else if dup { } else if dup {
common.ApiErrorMsg(c, "组名称已存在") common.ApiErrorMsg(c, "组名称已存在")
return return
} }
if err := g.Update(); err != nil { if err := g.Update(); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
common.ApiSuccess(c, &g) common.ApiSuccess(c, &g)
} }
// DeletePrefillGroup 删除预填组 // DeletePrefillGroup 删除预填组
func DeletePrefillGroup(c *gin.Context) { func DeletePrefillGroup(c *gin.Context) {
idStr := c.Param("id") idStr := c.Param("id")
id, err := strconv.Atoi(idStr) id, err := strconv.Atoi(idStr)
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if err := model.DeletePrefillGroupByID(id); err != nil { if err := model.DeletePrefillGroupByID(id); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
common.ApiSuccess(c, nil) common.ApiSuccess(c, nil)
} }

View File

@@ -39,14 +39,14 @@ func GetPricing(c *gin.Context) {
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"data": pricing, "data": pricing,
"vendors": model.GetVendors(), "vendors": model.GetVendors(),
"group_ratio": groupRatio, "group_ratio": groupRatio,
"usable_group": usableGroup, "usable_group": usableGroup,
"supported_endpoint": model.GetSupportedEndpointMap(), "supported_endpoint": model.GetSupportedEndpointMap(),
"auto_groups": setting.AutoGroups, "auto_groups": setting.AutoGroups,
}) })
} }
func ResetModelRatio(c *gin.Context) { func ResetModelRatio(c *gin.Context) {

View File

@@ -62,7 +62,7 @@ func Login(c *gin.Context) {
}) })
return return
} }
// 检查是否启用2FA // 检查是否启用2FA
if model.IsTwoFAEnabled(user.Id) { if model.IsTwoFAEnabled(user.Id) {
// 设置pending session等待2FA验证 // 设置pending session等待2FA验证
@@ -77,7 +77,7 @@ func Login(c *gin.Context) {
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "请输入两步验证码", "message": "请输入两步验证码",
"success": true, "success": true,
@@ -87,7 +87,7 @@ func Login(c *gin.Context) {
}) })
return return
} }
setupLogin(&user, c) setupLogin(&user, c)
} }

View File

@@ -1,124 +1,124 @@
package controller package controller
import ( import (
"strconv" "strconv"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// GetAllVendors 获取供应商列表(分页) // GetAllVendors 获取供应商列表(分页)
func GetAllVendors(c *gin.Context) { func GetAllVendors(c *gin.Context) {
pageInfo := common.GetPageQuery(c) pageInfo := common.GetPageQuery(c)
vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
var total int64 var total int64
model.DB.Model(&model.Vendor{}).Count(&total) model.DB.Model(&model.Vendor{}).Count(&total)
pageInfo.SetTotal(int(total)) pageInfo.SetTotal(int(total))
pageInfo.SetItems(vendors) pageInfo.SetItems(vendors)
common.ApiSuccess(c, pageInfo) common.ApiSuccess(c, pageInfo)
} }
// SearchVendors 搜索供应商 // SearchVendors 搜索供应商
func SearchVendors(c *gin.Context) { func SearchVendors(c *gin.Context) {
keyword := c.Query("keyword") keyword := c.Query("keyword")
pageInfo := common.GetPageQuery(c) pageInfo := common.GetPageQuery(c)
vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
pageInfo.SetTotal(int(total)) pageInfo.SetTotal(int(total))
pageInfo.SetItems(vendors) pageInfo.SetItems(vendors)
common.ApiSuccess(c, pageInfo) common.ApiSuccess(c, pageInfo)
} }
// GetVendorMeta 根据 ID 获取供应商 // GetVendorMeta 根据 ID 获取供应商
func GetVendorMeta(c *gin.Context) { func GetVendorMeta(c *gin.Context) {
idStr := c.Param("id") idStr := c.Param("id")
id, err := strconv.Atoi(idStr) id, err := strconv.Atoi(idStr)
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
v, err := model.GetVendorByID(id) v, err := model.GetVendorByID(id)
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
common.ApiSuccess(c, v) common.ApiSuccess(c, v)
} }
// CreateVendorMeta 新建供应商 // CreateVendorMeta 新建供应商
func CreateVendorMeta(c *gin.Context) { func CreateVendorMeta(c *gin.Context) {
var v model.Vendor var v model.Vendor
if err := c.ShouldBindJSON(&v); err != nil { if err := c.ShouldBindJSON(&v); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if v.Name == "" { if v.Name == "" {
common.ApiErrorMsg(c, "供应商名称不能为空") common.ApiErrorMsg(c, "供应商名称不能为空")
return return
} }
// 创建前先检查名称 // 创建前先检查名称
if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil { if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} else if dup { } else if dup {
common.ApiErrorMsg(c, "供应商名称已存在") common.ApiErrorMsg(c, "供应商名称已存在")
return return
} }
if err := v.Insert(); err != nil { if err := v.Insert(); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
common.ApiSuccess(c, &v) common.ApiSuccess(c, &v)
} }
// UpdateVendorMeta 更新供应商 // UpdateVendorMeta 更新供应商
func UpdateVendorMeta(c *gin.Context) { func UpdateVendorMeta(c *gin.Context) {
var v model.Vendor var v model.Vendor
if err := c.ShouldBindJSON(&v); err != nil { if err := c.ShouldBindJSON(&v); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if v.Id == 0 { if v.Id == 0 {
common.ApiErrorMsg(c, "缺少供应商 ID") common.ApiErrorMsg(c, "缺少供应商 ID")
return return
} }
// 名称冲突检查 // 名称冲突检查
if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil { if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} else if dup { } else if dup {
common.ApiErrorMsg(c, "供应商名称已存在") common.ApiErrorMsg(c, "供应商名称已存在")
return return
} }
if err := v.Update(); err != nil { if err := v.Update(); err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
common.ApiSuccess(c, &v) common.ApiSuccess(c, &v)
} }
// DeleteVendorMeta 删除供应商 // DeleteVendorMeta 删除供应商
func DeleteVendorMeta(c *gin.Context) { func DeleteVendorMeta(c *gin.Context) {
idStr := c.Param("id") idStr := c.Param("id")
id, err := strconv.Atoi(idStr) id, err := strconv.Atoi(idStr)
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil { if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
common.ApiError(c, err) common.ApiError(c, err)
return return
} }
common.ApiSuccess(c, nil) common.ApiSuccess(c, nil)
} }

View File

@@ -2,29 +2,29 @@ package model
// GetMissingModels returns model names that are referenced in the system // GetMissingModels returns model names that are referenced in the system
func GetMissingModels() ([]string, error) { func GetMissingModels() ([]string, error) {
// 1. 获取所有已启用模型(去重) // 1. 获取所有已启用模型(去重)
models := GetEnabledModels() models := GetEnabledModels()
if len(models) == 0 { if len(models) == 0 {
return []string{}, nil return []string{}, nil
} }
// 2. 查询已有的元数据模型名 // 2. 查询已有的元数据模型名
var existing []string var existing []string
if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil { if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil {
return nil, err return nil, err
} }
existingSet := make(map[string]struct{}, len(existing)) existingSet := make(map[string]struct{}, len(existing))
for _, e := range existing { for _, e := range existing {
existingSet[e] = struct{}{} existingSet[e] = struct{}{}
} }
// 3. 收集缺失模型 // 3. 收集缺失模型
var missing []string var missing []string
for _, name := range models { for _, name := range models {
if _, ok := existingSet[name]; !ok { if _, ok := existingSet[name]; !ok {
missing = append(missing, name) missing = append(missing, name)
} }
} }
return missing, nil return missing, nil
} }

View File

@@ -3,32 +3,32 @@ package model
// GetModelEnableGroups 返回指定模型名称可用的用户分组列表。 // GetModelEnableGroups 返回指定模型名称可用的用户分组列表。
// 使用在 updatePricing() 中维护的缓存映射O(1) 读取,适合高并发场景。 // 使用在 updatePricing() 中维护的缓存映射O(1) 读取,适合高并发场景。
func GetModelEnableGroups(modelName string) []string { func GetModelEnableGroups(modelName string) []string {
// 确保缓存最新 // 确保缓存最新
GetPricing() GetPricing()
if modelName == "" { if modelName == "" {
return make([]string, 0) return make([]string, 0)
} }
modelEnableGroupsLock.RLock() modelEnableGroupsLock.RLock()
groups, ok := modelEnableGroups[modelName] groups, ok := modelEnableGroups[modelName]
modelEnableGroupsLock.RUnlock() modelEnableGroupsLock.RUnlock()
if !ok { if !ok {
return make([]string, 0) return make([]string, 0)
} }
return groups return groups
} }
// GetModelQuotaType 返回指定模型的计费类型quota_type // GetModelQuotaType 返回指定模型的计费类型quota_type
// 同样使用缓存映射,避免每次遍历定价切片。 // 同样使用缓存映射,避免每次遍历定价切片。
func GetModelQuotaType(modelName string) int { func GetModelQuotaType(modelName string) int {
GetPricing() GetPricing()
modelEnableGroupsLock.RLock() modelEnableGroupsLock.RLock()
quota, ok := modelQuotaTypeMap[modelName] quota, ok := modelQuotaTypeMap[modelName]
modelEnableGroupsLock.RUnlock() modelEnableGroupsLock.RUnlock()
if !ok { if !ok {
return 0 return 0
} }
return quota return quota
} }

View File

@@ -1,11 +1,11 @@
package model package model
import ( import (
"one-api/common" "one-api/common"
"strconv" "strconv"
"strings" "strings"
"gorm.io/gorm" "gorm.io/gorm"
) )
// Model 用于存储模型的元数据,例如描述、标签等 // Model 用于存储模型的元数据,例如描述、标签等
@@ -23,186 +23,186 @@ import (
// 模型名称匹配规则 // 模型名称匹配规则
const ( const (
NameRuleExact = iota // 0 精确匹配 NameRuleExact = iota // 0 精确匹配
NameRulePrefix // 1 前缀匹配 NameRulePrefix // 1 前缀匹配
NameRuleContains // 2 包含匹配 NameRuleContains // 2 包含匹配
NameRuleSuffix // 3 后缀匹配 NameRuleSuffix // 3 后缀匹配
) )
type BoundChannel struct { type BoundChannel struct {
Name string `json:"name"` Name string `json:"name"`
Type int `json:"type"` Type int `json:"type"`
} }
type Model struct { type Model struct {
Id int `json:"id"` Id int `json:"id"`
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"` ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"`
Description string `json:"description,omitempty" gorm:"type:text"` Description string `json:"description,omitempty" gorm:"type:text"`
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"` Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
VendorID int `json:"vendor_id,omitempty" gorm:"index"` VendorID int `json:"vendor_id,omitempty" gorm:"index"`
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"` Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
CreatedTime int64 `json:"created_time" gorm:"bigint"` CreatedTime int64 `json:"created_time" gorm:"bigint"`
UpdatedTime int64 `json:"updated_time" gorm:"bigint"` UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"`
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"` BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"` EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
QuotaType int `json:"quota_type" gorm:"-"` QuotaType int `json:"quota_type" gorm:"-"`
NameRule int `json:"name_rule" gorm:"default:0"` NameRule int `json:"name_rule" gorm:"default:0"`
} }
// Insert 创建新的模型元数据记录 // Insert 创建新的模型元数据记录
func (mi *Model) Insert() error { func (mi *Model) Insert() error {
now := common.GetTimestamp() now := common.GetTimestamp()
mi.CreatedTime = now mi.CreatedTime = now
mi.UpdatedTime = now mi.UpdatedTime = now
return DB.Create(mi).Error return DB.Create(mi).Error
} }
// IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID // IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID
func IsModelNameDuplicated(id int, name string) (bool, error) { func IsModelNameDuplicated(id int, name string) (bool, error) {
if name == "" { if name == "" {
return false, nil return false, nil
} }
var cnt int64 var cnt int64
err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error
return cnt > 0, err return cnt > 0, err
} }
// Update 更新现有模型记录 // Update 更新现有模型记录
func (mi *Model) Update() error { func (mi *Model) Update() error {
mi.UpdatedTime = common.GetTimestamp() mi.UpdatedTime = common.GetTimestamp()
// 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新 // 使用 Session 配置并选择所有字段,允许零值(如空字符串)也能被更新
return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}). return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
Model(&Model{}). Model(&Model{}).
Where("id = ?", mi.Id). Where("id = ?", mi.Id).
Omit("created_time"). Omit("created_time").
Select("*"). Select("*").
Updates(mi).Error Updates(mi).Error
} }
// Delete 软删除模型记录 // Delete 软删除模型记录
func (mi *Model) Delete() error { func (mi *Model) Delete() error {
return DB.Delete(mi).Error return DB.Delete(mi).Error
} }
// GetModelByName 根据模型名称查询元数据 // GetModelByName 根据模型名称查询元数据
func GetModelByName(name string) (*Model, error) { func GetModelByName(name string) (*Model, error) {
var mi Model var mi Model
err := DB.Where("model_name = ?", name).First(&mi).Error err := DB.Where("model_name = ?", name).First(&mi).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &mi, nil return &mi, nil
} }
// GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响) // GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响)
func GetVendorModelCounts() (map[int64]int64, error) { func GetVendorModelCounts() (map[int64]int64, error) {
var stats []struct { var stats []struct {
VendorID int64 VendorID int64
Count int64 Count int64
} }
if err := DB.Model(&Model{}). if err := DB.Model(&Model{}).
Select("vendor_id as vendor_id, count(*) as count"). Select("vendor_id as vendor_id, count(*) as count").
Group("vendor_id"). Group("vendor_id").
Scan(&stats).Error; err != nil { Scan(&stats).Error; err != nil {
return nil, err return nil, err
} }
m := make(map[int64]int64, len(stats)) m := make(map[int64]int64, len(stats))
for _, s := range stats { for _, s := range stats {
m[s.VendorID] = s.Count m[s.VendorID] = s.Count
} }
return m, nil return m, nil
} }
// GetAllModels 分页获取所有模型元数据 // GetAllModels 分页获取所有模型元数据
func GetAllModels(offset int, limit int) ([]*Model, error) { func GetAllModels(offset int, limit int) ([]*Model, error) {
var models []*Model var models []*Model
err := DB.Offset(offset).Limit(limit).Find(&models).Error err := DB.Offset(offset).Limit(limit).Find(&models).Error
return models, err return models, err
} }
// GetBoundChannels 查询支持该模型的渠道(名称+类型) // GetBoundChannels 查询支持该模型的渠道(名称+类型)
func GetBoundChannels(modelName string) ([]BoundChannel, error) { func GetBoundChannels(modelName string) ([]BoundChannel, error) {
var channels []BoundChannel var channels []BoundChannel
err := DB.Table("channels"). err := DB.Table("channels").
Select("channels.name, channels.type"). Select("channels.name, channels.type").
Joins("join abilities on abilities.channel_id = channels.id"). Joins("join abilities on abilities.channel_id = channels.id").
Where("abilities.model = ? AND abilities.enabled = ?", modelName, true). Where("abilities.model = ? AND abilities.enabled = ?", modelName, true).
Group("channels.id"). Group("channels.id").
Scan(&channels).Error Scan(&channels).Error
return channels, err return channels, err
} }
// FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含 // FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含
func FindModelByNameWithRule(name string) (*Model, error) { func FindModelByNameWithRule(name string) (*Model, error) {
// 1. 精确匹配 // 1. 精确匹配
if m, err := GetModelByName(name); err == nil { if m, err := GetModelByName(name); err == nil {
return m, nil return m, nil
} }
// 2. 规则匹配 // 2. 规则匹配
var models []*Model var models []*Model
if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil { if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil {
return nil, err return nil, err
} }
var prefixMatch, suffixMatch, containsMatch *Model var prefixMatch, suffixMatch, containsMatch *Model
for _, m := range models { for _, m := range models {
switch m.NameRule { switch m.NameRule {
case NameRulePrefix: case NameRulePrefix:
if strings.HasPrefix(name, m.ModelName) { if strings.HasPrefix(name, m.ModelName) {
if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) { if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) {
prefixMatch = m prefixMatch = m
} }
} }
case NameRuleSuffix: case NameRuleSuffix:
if strings.HasSuffix(name, m.ModelName) { if strings.HasSuffix(name, m.ModelName) {
if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) { if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) {
suffixMatch = m suffixMatch = m
} }
} }
case NameRuleContains: case NameRuleContains:
if strings.Contains(name, m.ModelName) { if strings.Contains(name, m.ModelName) {
if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) { if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) {
containsMatch = m containsMatch = m
} }
} }
} }
} }
if prefixMatch != nil { if prefixMatch != nil {
return prefixMatch, nil return prefixMatch, nil
} }
if suffixMatch != nil { if suffixMatch != nil {
return suffixMatch, nil return suffixMatch, nil
} }
if containsMatch != nil { if containsMatch != nil {
return containsMatch, nil return containsMatch, nil
} }
return nil, gorm.ErrRecordNotFound return nil, gorm.ErrRecordNotFound
} }
// SearchModels 根据关键词和供应商搜索模型,支持分页 // SearchModels 根据关键词和供应商搜索模型,支持分页
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) { func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
var models []*Model var models []*Model
db := DB.Model(&Model{}) db := DB.Model(&Model{})
if keyword != "" { if keyword != "" {
like := "%" + keyword + "%" like := "%" + keyword + "%"
db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like) db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
} }
if vendor != "" { if vendor != "" {
// 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配 // 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配
if vid, err := strconv.Atoi(vendor); err == nil { if vid, err := strconv.Atoi(vendor); err == nil {
db = db.Where("models.vendor_id = ?", vid) db = db.Where("models.vendor_id = ?", vid)
} else { } else {
db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%") db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
} }
} }
var total int64 var total int64
err := db.Count(&total).Error err := db.Count(&total).Error
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error
return models, total, err return models, total, err
} }

View File

@@ -1,11 +1,11 @@
package model package model
import ( import (
"encoding/json" "database/sql/driver"
"database/sql/driver" "encoding/json"
"one-api/common" "one-api/common"
"gorm.io/gorm" "gorm.io/gorm"
) )
// PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。 // PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。
@@ -20,107 +20,107 @@ type JSONValue json.RawMessage
// Value 实现 driver.Valuer 接口,用于数据库写入 // Value 实现 driver.Valuer 接口,用于数据库写入
func (j JSONValue) Value() (driver.Value, error) { func (j JSONValue) Value() (driver.Value, error) {
if j == nil { if j == nil {
return nil, nil return nil, nil
} }
return []byte(j), nil return []byte(j), nil
} }
// Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型 // Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型
func (j *JSONValue) Scan(value interface{}) error { func (j *JSONValue) Scan(value interface{}) error {
switch v := value.(type) { switch v := value.(type) {
case nil: case nil:
*j = nil *j = nil
return nil return nil
case []byte: case []byte:
// 拷贝底层字节,避免保留底层缓冲区 // 拷贝底层字节,避免保留底层缓冲区
b := make([]byte, len(v)) b := make([]byte, len(v))
copy(b, v) copy(b, v)
*j = JSONValue(b) *j = JSONValue(b)
return nil return nil
case string: case string:
*j = JSONValue([]byte(v)) *j = JSONValue([]byte(v))
return nil return nil
default: default:
// 其他类型尝试序列化为 JSON // 其他类型尝试序列化为 JSON
b, err := json.Marshal(v) b, err := json.Marshal(v)
if err != nil { if err != nil {
return err return err
} }
*j = JSONValue(b) *j = JSONValue(b)
return nil return nil
} }
} }
// MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致 // MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致
func (j JSONValue) MarshalJSON() ([]byte, error) { func (j JSONValue) MarshalJSON() ([]byte, error) {
if j == nil { if j == nil {
return []byte("null"), nil return []byte("null"), nil
} }
return j, nil return j, nil
} }
// UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致 // UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致
func (j *JSONValue) UnmarshalJSON(data []byte) error { func (j *JSONValue) UnmarshalJSON(data []byte) error {
if data == nil { if data == nil {
*j = nil *j = nil
return nil return nil
} }
b := make([]byte, len(data)) b := make([]byte, len(data))
copy(b, data) copy(b, data)
*j = JSONValue(b) *j = JSONValue(b)
return nil return nil
} }
type PrefillGroup struct { type PrefillGroup struct {
Id int `json:"id"` Id int `json:"id"`
Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"` Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"`
Type string `json:"type" gorm:"size:32;index;not null"` Type string `json:"type" gorm:"size:32;index;not null"`
Items JSONValue `json:"items" gorm:"type:json"` Items JSONValue `json:"items" gorm:"type:json"`
Description string `json:"description,omitempty" gorm:"type:varchar(255)"` Description string `json:"description,omitempty" gorm:"type:varchar(255)"`
CreatedTime int64 `json:"created_time" gorm:"bigint"` CreatedTime int64 `json:"created_time" gorm:"bigint"`
UpdatedTime int64 `json:"updated_time" gorm:"bigint"` UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
} }
// Insert 新建组 // Insert 新建组
func (g *PrefillGroup) Insert() error { func (g *PrefillGroup) Insert() error {
now := common.GetTimestamp() now := common.GetTimestamp()
g.CreatedTime = now g.CreatedTime = now
g.UpdatedTime = now g.UpdatedTime = now
return DB.Create(g).Error return DB.Create(g).Error
} }
// IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID // IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID
func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) { func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) {
if name == "" { if name == "" {
return false, nil return false, nil
} }
var cnt int64 var cnt int64
err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
return cnt > 0, err return cnt > 0, err
} }
// Update 更新组 // Update 更新组
func (g *PrefillGroup) Update() error { func (g *PrefillGroup) Update() error {
g.UpdatedTime = common.GetTimestamp() g.UpdatedTime = common.GetTimestamp()
return DB.Save(g).Error return DB.Save(g).Error
} }
// DeleteByID 根据 ID 删除组 // DeleteByID 根据 ID 删除组
func DeletePrefillGroupByID(id int) error { func DeletePrefillGroupByID(id int) error {
return DB.Delete(&PrefillGroup{}, id).Error return DB.Delete(&PrefillGroup{}, id).Error
} }
// GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部) // GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部)
func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) { func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) {
var groups []*PrefillGroup var groups []*PrefillGroup
query := DB.Model(&PrefillGroup{}) query := DB.Model(&PrefillGroup{})
if groupType != "" { if groupType != "" {
query = query.Where("type = ?", groupType) query = query.Where("type = ?", groupType)
} }
if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil { if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil {
return nil, err return nil, err
} }
return groups, nil return groups, nil
} }

View File

@@ -1,31 +1,31 @@
package model package model
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings" "strings"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/setting/ratio_setting" "one-api/setting/ratio_setting"
"one-api/types" "one-api/types"
"sync" "sync"
"time" "time"
) )
type Pricing struct { type Pricing struct {
ModelName string `json:"model_name"` ModelName string `json:"model_name"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Icon string `json:"icon,omitempty"` Icon string `json:"icon,omitempty"`
Tags string `json:"tags,omitempty"` Tags string `json:"tags,omitempty"`
VendorID int `json:"vendor_id,omitempty"` VendorID int `json:"vendor_id,omitempty"`
QuotaType int `json:"quota_type"` QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"` ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"` ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"` OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"` CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups"` EnableGroup []string `json:"enable_groups"`
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
} }
type PricingVendor struct { type PricingVendor struct {
@@ -36,11 +36,11 @@ type PricingVendor struct {
} }
var ( var (
pricingMap []Pricing pricingMap []Pricing
vendorsList []PricingVendor vendorsList []PricingVendor
supportedEndpointMap map[string]common.EndpointInfo supportedEndpointMap map[string]common.EndpointInfo
lastGetPricingTime time.Time lastGetPricingTime time.Time
updatePricingLock sync.Mutex updatePricingLock sync.Mutex
// 缓存映射:模型名 -> 启用分组 / 计费类型 // 缓存映射:模型名 -> 启用分组 / 计费类型
modelEnableGroups = make(map[string][]string) modelEnableGroups = make(map[string][]string)
@@ -122,19 +122,19 @@ func updatePricing() {
for _, m := range prefixList { for _, m := range prefixList {
for _, pricingModel := range enableAbilities { for _, pricingModel := range enableAbilities {
if strings.HasPrefix(pricingModel.Model, m.ModelName) { if strings.HasPrefix(pricingModel.Model, m.ModelName) {
if _, exists := metaMap[pricingModel.Model]; !exists { if _, exists := metaMap[pricingModel.Model]; !exists {
metaMap[pricingModel.Model] = m metaMap[pricingModel.Model] = m
} }
} }
} }
} }
for _, m := range suffixList { for _, m := range suffixList {
for _, pricingModel := range enableAbilities { for _, pricingModel := range enableAbilities {
if strings.HasSuffix(pricingModel.Model, m.ModelName) { if strings.HasSuffix(pricingModel.Model, m.ModelName) {
if _, exists := metaMap[pricingModel.Model]; !exists { if _, exists := metaMap[pricingModel.Model]; !exists {
metaMap[pricingModel.Model] = m metaMap[pricingModel.Model] = m
} }
} }
} }
} }
for _, m := range containsList { for _, m := range containsList {
@@ -180,34 +180,34 @@ func updatePricing() {
//这里使用切片而不是Set因为一个模型可能支持多个端点类型并且第一个端点是优先使用端点 //这里使用切片而不是Set因为一个模型可能支持多个端点类型并且第一个端点是优先使用端点
modelSupportEndpointsStr := make(map[string][]string) modelSupportEndpointsStr := make(map[string][]string)
// 先根据已有能力填充原生端点 // 先根据已有能力填充原生端点
for _, ability := range enableAbilities { for _, ability := range enableAbilities {
endpoints := modelSupportEndpointsStr[ability.Model] endpoints := modelSupportEndpointsStr[ability.Model]
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
for _, channelType := range channelTypes { for _, channelType := range channelTypes {
if !common.StringsContains(endpoints, string(channelType)) { if !common.StringsContains(endpoints, string(channelType)) {
endpoints = append(endpoints, string(channelType)) endpoints = append(endpoints, string(channelType))
} }
} }
modelSupportEndpointsStr[ability.Model] = endpoints modelSupportEndpointsStr[ability.Model] = endpoints
} }
// 再补充模型自定义端点 // 再补充模型自定义端点
for modelName, meta := range metaMap { for modelName, meta := range metaMap {
if strings.TrimSpace(meta.Endpoints) == "" { if strings.TrimSpace(meta.Endpoints) == "" {
continue continue
} }
var raw map[string]interface{} var raw map[string]interface{}
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
endpoints := modelSupportEndpointsStr[modelName] endpoints := modelSupportEndpointsStr[modelName]
for k := range raw { for k := range raw {
if !common.StringsContains(endpoints, k) { if !common.StringsContains(endpoints, k) {
endpoints = append(endpoints, k) endpoints = append(endpoints, k)
} }
} }
modelSupportEndpointsStr[modelName] = endpoints modelSupportEndpointsStr[modelName] = endpoints
} }
} }
modelSupportEndpointTypes = make(map[string][]constant.EndpointType) modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
for model, endpoints := range modelSupportEndpointsStr { for model, endpoints := range modelSupportEndpointsStr {
@@ -217,93 +217,93 @@ func updatePricing() {
supportedEndpoints = append(supportedEndpoints, endpointType) supportedEndpoints = append(supportedEndpoints, endpointType)
} }
modelSupportEndpointTypes[model] = supportedEndpoints modelSupportEndpointTypes[model] = supportedEndpoints
} }
// 构建全局 supportedEndpointMap默认 + 自定义覆盖) // 构建全局 supportedEndpointMap默认 + 自定义覆盖)
supportedEndpointMap = make(map[string]common.EndpointInfo) supportedEndpointMap = make(map[string]common.EndpointInfo)
// 1. 默认端点 // 1. 默认端点
for _, endpoints := range modelSupportEndpointTypes { for _, endpoints := range modelSupportEndpointTypes {
for _, et := range endpoints { for _, et := range endpoints {
if info, ok := common.GetDefaultEndpointInfo(et); ok { if info, ok := common.GetDefaultEndpointInfo(et); ok {
if _, exists := supportedEndpointMap[string(et)]; !exists { if _, exists := supportedEndpointMap[string(et)]; !exists {
supportedEndpointMap[string(et)] = info supportedEndpointMap[string(et)] = info
} }
} }
} }
} }
// 2. 自定义端点models 表)覆盖默认 // 2. 自定义端点models 表)覆盖默认
for _, meta := range metaMap { for _, meta := range metaMap {
if strings.TrimSpace(meta.Endpoints) == "" { if strings.TrimSpace(meta.Endpoints) == "" {
continue continue
} }
var raw map[string]interface{} var raw map[string]interface{}
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
for k, v := range raw { for k, v := range raw {
switch val := v.(type) { switch val := v.(type) {
case string: case string:
supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"} supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
case map[string]interface{}: case map[string]interface{}:
ep := common.EndpointInfo{Method: "POST"} ep := common.EndpointInfo{Method: "POST"}
if p, ok := val["path"].(string); ok { if p, ok := val["path"].(string); ok {
ep.Path = p ep.Path = p
} }
if m, ok := val["method"].(string); ok { if m, ok := val["method"].(string); ok {
ep.Method = strings.ToUpper(m) ep.Method = strings.ToUpper(m)
} }
supportedEndpointMap[k] = ep supportedEndpointMap[k] = ep
default: default:
// ignore unsupported types // ignore unsupported types
} }
} }
} }
} }
pricingMap = make([]Pricing, 0) pricingMap = make([]Pricing, 0)
for model, groups := range modelGroupsMap { for model, groups := range modelGroupsMap {
pricing := Pricing{ pricing := Pricing{
ModelName: model, ModelName: model,
EnableGroup: groups.Items(), EnableGroup: groups.Items(),
SupportedEndpointTypes: modelSupportEndpointTypes[model], SupportedEndpointTypes: modelSupportEndpointTypes[model],
} }
// 补充模型元数据(描述、标签、供应商、状态) // 补充模型元数据(描述、标签、供应商、状态)
if meta, ok := metaMap[model]; ok { if meta, ok := metaMap[model]; ok {
// 若模型被禁用(status!=1),则直接跳过,不返回给前端 // 若模型被禁用(status!=1),则直接跳过,不返回给前端
if meta.Status != 1 { if meta.Status != 1 {
continue continue
} }
pricing.Description = meta.Description pricing.Description = meta.Description
pricing.Icon = meta.Icon pricing.Icon = meta.Icon
pricing.Tags = meta.Tags pricing.Tags = meta.Tags
pricing.VendorID = meta.VendorID pricing.VendorID = meta.VendorID
} }
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
if findPrice { if findPrice {
pricing.ModelPrice = modelPrice pricing.ModelPrice = modelPrice
pricing.QuotaType = 1 pricing.QuotaType = 1
} else { } else {
modelRatio, _, _ := ratio_setting.GetModelRatio(model) modelRatio, _, _ := ratio_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio pricing.ModelRatio = modelRatio
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model) pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
pricing.QuotaType = 0 pricing.QuotaType = 0
} }
pricingMap = append(pricingMap, pricing) pricingMap = append(pricingMap, pricing)
} }
// 刷新缓存映射,供高并发快速查询 // 刷新缓存映射,供高并发快速查询
modelEnableGroupsLock.Lock() modelEnableGroupsLock.Lock()
modelEnableGroups = make(map[string][]string) modelEnableGroups = make(map[string][]string)
modelQuotaTypeMap = make(map[string]int) modelQuotaTypeMap = make(map[string]int)
for _, p := range pricingMap { for _, p := range pricingMap {
modelEnableGroups[p.ModelName] = p.EnableGroup modelEnableGroups[p.ModelName] = p.EnableGroup
modelQuotaTypeMap[p.ModelName] = p.QuotaType modelQuotaTypeMap[p.ModelName] = p.QuotaType
} }
modelEnableGroupsLock.Unlock() modelEnableGroupsLock.Unlock()
lastGetPricingTime = time.Now() lastGetPricingTime = time.Now()
} }
// GetSupportedEndpointMap 返回全局端点到路径的映射 // GetSupportedEndpointMap 返回全局端点到路径的映射
func GetSupportedEndpointMap() map[string]common.EndpointInfo { func GetSupportedEndpointMap() map[string]common.EndpointInfo {
return supportedEndpointMap return supportedEndpointMap
} }

View File

@@ -4,11 +4,11 @@ package model
// 该方法用于需要最新数据的内部管理 API // 该方法用于需要最新数据的内部管理 API
// 因此会绕过默认的 1 分钟延迟刷新。 // 因此会绕过默认的 1 分钟延迟刷新。
func RefreshPricing() { func RefreshPricing() {
updatePricingLock.Lock() updatePricingLock.Lock()
defer updatePricingLock.Unlock() defer updatePricingLock.Unlock()
modelSupportEndpointsLock.Lock() modelSupportEndpointsLock.Lock()
defer modelSupportEndpointsLock.Unlock() defer modelSupportEndpointsLock.Unlock()
updatePricing() updatePricing()
} }

View File

@@ -1,9 +1,9 @@
package model package model
import ( import (
"one-api/common" "one-api/common"
"gorm.io/gorm" "gorm.io/gorm"
) )
// Vendor 用于存储供应商信息,供模型引用 // Vendor 用于存储供应商信息,供模型引用
@@ -13,76 +13,76 @@ import (
// 本表同样遵循 3NF 设计范式 // 本表同样遵循 3NF 设计范式
type Vendor struct { type Vendor struct {
Id int `json:"id"` Id int `json:"id"`
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"` Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"`
Description string `json:"description,omitempty" gorm:"type:text"` Description string `json:"description,omitempty" gorm:"type:text"`
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
CreatedTime int64 `json:"created_time" gorm:"bigint"` CreatedTime int64 `json:"created_time" gorm:"bigint"`
UpdatedTime int64 `json:"updated_time" gorm:"bigint"` UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"`
} }
// Insert 创建新的供应商记录 // Insert 创建新的供应商记录
func (v *Vendor) Insert() error { func (v *Vendor) Insert() error {
now := common.GetTimestamp() now := common.GetTimestamp()
v.CreatedTime = now v.CreatedTime = now
v.UpdatedTime = now v.UpdatedTime = now
return DB.Create(v).Error return DB.Create(v).Error
} }
// IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID // IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID
func IsVendorNameDuplicated(id int, name string) (bool, error) { func IsVendorNameDuplicated(id int, name string) (bool, error) {
if name == "" { if name == "" {
return false, nil return false, nil
} }
var cnt int64 var cnt int64
err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
return cnt > 0, err return cnt > 0, err
} }
// Update 更新供应商记录 // Update 更新供应商记录
func (v *Vendor) Update() error { func (v *Vendor) Update() error {
v.UpdatedTime = common.GetTimestamp() v.UpdatedTime = common.GetTimestamp()
return DB.Save(v).Error return DB.Save(v).Error
} }
// Delete 软删除供应商 // Delete 软删除供应商
func (v *Vendor) Delete() error { func (v *Vendor) Delete() error {
return DB.Delete(v).Error return DB.Delete(v).Error
} }
// GetVendorByID 根据 ID 获取供应商 // GetVendorByID 根据 ID 获取供应商
func GetVendorByID(id int) (*Vendor, error) { func GetVendorByID(id int) (*Vendor, error) {
var v Vendor var v Vendor
err := DB.First(&v, id).Error err := DB.First(&v, id).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &v, nil return &v, nil
} }
// GetAllVendors 获取全部供应商(分页) // GetAllVendors 获取全部供应商(分页)
func GetAllVendors(offset int, limit int) ([]*Vendor, error) { func GetAllVendors(offset int, limit int) ([]*Vendor, error) {
var vendors []*Vendor var vendors []*Vendor
err := DB.Offset(offset).Limit(limit).Find(&vendors).Error err := DB.Offset(offset).Limit(limit).Find(&vendors).Error
return vendors, err return vendors, err
} }
// SearchVendors 按关键字搜索供应商 // SearchVendors 按关键字搜索供应商
func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) { func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) {
db := DB.Model(&Vendor{}) db := DB.Model(&Vendor{})
if keyword != "" { if keyword != "" {
like := "%" + keyword + "%" like := "%" + keyword + "%"
db = db.Where("name LIKE ? OR description LIKE ?", like, like) db = db.Where("name LIKE ? OR description LIKE ?", like, like)
} }
var total int64 var total int64
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, 0, err return nil, 0, err
} }
var vendors []*Vendor var vendors []*Vendor
if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil { if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil {
return nil, 0, err return nil, 0, err
} }
return vendors, total, nil return vendors, total, nil
} }