diff --git a/common/endpoint_defaults.go b/common/endpoint_defaults.go new file mode 100644 index 00000000..1dfe1dc9 --- /dev/null +++ b/common/endpoint_defaults.go @@ -0,0 +1,32 @@ +package common + +import "one-api/constant" + +// EndpointInfo 描述单个端点的默认请求信息 +// path: 上游路径 +// method: HTTP 请求方式,例如 POST/GET +// 目前均为 POST,后续可扩展 +// +// json 标签用于直接序列化到 API 输出 +// 例如:{"path":"/v1/chat/completions","method":"POST"} + +type EndpointInfo struct { + Path string `json:"path"` + Method string `json:"method"` +} + +// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method +var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{ + constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"}, + constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"}, + constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"}, + constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"}, + constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"}, + constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"}, +} + +// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在 +func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) { + info, ok := defaultEndpointInfoMap[et] + return info, ok +} diff --git a/controller/missing_models.go b/controller/missing_models.go new file mode 100644 index 00000000..a3409e29 --- /dev/null +++ b/controller/missing_models.go @@ -0,0 +1,27 @@ +package controller + +import ( + "net/http" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +// GetMissingModels returns the list of model names that are referenced by channels +// but do not have corresponding records in the models meta table. +// This helps administrators quickly discover models that need configuration. +func GetMissingModels(c *gin.Context) { + missing, err := model.GetMissingModels() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": missing, + }) +} diff --git a/controller/model_meta.go b/controller/model_meta.go new file mode 100644 index 00000000..090ea3c1 --- /dev/null +++ b/controller/model_meta.go @@ -0,0 +1,178 @@ +package controller + +import ( + "encoding/json" + "strconv" + + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +// GetAllModelsMeta 获取模型列表(分页) +func GetAllModelsMeta(c *gin.Context) { + + pageInfo := common.GetPageQuery(c) + modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + // 填充附加字段 + for _, m := range modelsMeta { + fillModelExtra(m) + } + var total int64 + model.DB.Model(&model.Model{}).Count(&total) + + // 统计供应商计数(全部数据,不受分页影响) + vendorCounts, _ := model.GetVendorModelCounts() + + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(modelsMeta) + common.ApiSuccess(c, gin.H{ + "items": modelsMeta, + "total": total, + "page": pageInfo.GetPage(), + "page_size": pageInfo.GetPageSize(), + "vendor_counts": vendorCounts, + }) +} + +// SearchModelsMeta 搜索模型列表 +func SearchModelsMeta(c *gin.Context) { + + keyword := c.Query("keyword") + vendor := c.Query("vendor") + pageInfo := common.GetPageQuery(c) + + modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + for _, m := range modelsMeta { + fillModelExtra(m) + } + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(modelsMeta) + common.ApiSuccess(c, pageInfo) +} + +// GetModelMeta 根据 ID 获取单条模型信息 +func GetModelMeta(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + var m model.Model + if err := model.DB.First(&m, id).Error; err != nil { + common.ApiError(c, err) + return + } + fillModelExtra(&m) + common.ApiSuccess(c, &m) +} + +// CreateModelMeta 新建模型 +func CreateModelMeta(c *gin.Context) { + var m model.Model + if err := c.ShouldBindJSON(&m); err != nil { + common.ApiError(c, err) + return + } + if m.ModelName == "" { + common.ApiErrorMsg(c, "模型名称不能为空") + return + } + // 名称冲突检查 + if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "模型名称已存在") + return + } + + if err := m.Insert(); err != nil { + common.ApiError(c, err) + return + } + model.RefreshPricing() + common.ApiSuccess(c, &m) +} + +// UpdateModelMeta 更新模型 +func UpdateModelMeta(c *gin.Context) { + statusOnly := c.Query("status_only") == "true" + + var m model.Model + if err := c.ShouldBindJSON(&m); err != nil { + common.ApiError(c, err) + return + } + if m.Id == 0 { + common.ApiErrorMsg(c, "缺少模型 ID") + return + } + + if statusOnly { + // 只更新状态,防止误清空其他字段 + if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil { + common.ApiError(c, err) + return + } + } else { + // 名称冲突检查 + if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "模型名称已存在") + return + } + + if err := m.Update(); err != nil { + common.ApiError(c, err) + return + } + } + model.RefreshPricing() + common.ApiSuccess(c, &m) +} + +// DeleteModelMeta 删除模型 +func DeleteModelMeta(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + if err := model.DB.Delete(&model.Model{}, id).Error; err != nil { + common.ApiError(c, err) + return + } + model.RefreshPricing() + common.ApiSuccess(c, nil) +} + +// 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups +func fillModelExtra(m *model.Model) { + if m.Endpoints == "" { + eps := model.GetModelSupportEndpointTypes(m.ModelName) + if b, err := json.Marshal(eps); err == nil { + m.Endpoints = string(b) + } + } + if channels, err := model.GetBoundChannels(m.ModelName); err == nil { + m.BoundChannels = channels + } + // 填充启用分组 + m.EnableGroups = model.GetModelEnableGroups(m.ModelName) + // 填充计费类型 + m.QuotaType = model.GetModelQuotaType(m.ModelName) +} diff --git a/controller/prefill_group.go b/controller/prefill_group.go new file mode 100644 index 00000000..4e29379b --- /dev/null +++ b/controller/prefill_group.go @@ -0,0 +1,90 @@ +package controller + +import ( + "strconv" + + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤 +func GetPrefillGroups(c *gin.Context) { + groupType := c.Query("type") + groups, err := model.GetAllPrefillGroups(groupType) + if err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, groups) +} + +// CreatePrefillGroup 创建新的预填组 +func CreatePrefillGroup(c *gin.Context) { + var g model.PrefillGroup + if err := c.ShouldBindJSON(&g); err != nil { + common.ApiError(c, err) + return + } + if g.Name == "" || g.Type == "" { + common.ApiErrorMsg(c, "组名称和类型不能为空") + return + } + // 创建前检查名称 + if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "组名称已存在") + return + } + + if err := g.Insert(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &g) +} + +// UpdatePrefillGroup 更新预填组 +func UpdatePrefillGroup(c *gin.Context) { + var g model.PrefillGroup + if err := c.ShouldBindJSON(&g); err != nil { + common.ApiError(c, err) + return + } + if g.Id == 0 { + common.ApiErrorMsg(c, "缺少组 ID") + return + } + // 名称冲突检查 + if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "组名称已存在") + return + } + + if err := g.Update(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &g) +} + +// DeletePrefillGroup 删除预填组 +func DeletePrefillGroup(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + if err := model.DeletePrefillGroupByID(id); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, nil) +} diff --git a/controller/pricing.go b/controller/pricing.go index f27336b7..e1719cf3 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -41,9 +41,11 @@ func GetPricing(c *gin.Context) { c.JSON(200, gin.H{ "success": true, "data": pricing, - "group_ratio": groupRatio, - "usable_group": usableGroup, - }) + "vendors": model.GetVendors(), + "group_ratio": groupRatio, + "usable_group": usableGroup, + "supported_endpoint": model.GetSupportedEndpointMap(), + }) } func ResetModelRatio(c *gin.Context) { diff --git a/controller/vendor_meta.go b/controller/vendor_meta.go new file mode 100644 index 00000000..28664dd6 --- /dev/null +++ b/controller/vendor_meta.go @@ -0,0 +1,124 @@ +package controller + +import ( + "strconv" + + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +// GetAllVendors 获取供应商列表(分页) +func GetAllVendors(c *gin.Context) { + pageInfo := common.GetPageQuery(c) + vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + var total int64 + model.DB.Model(&model.Vendor{}).Count(&total) + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(vendors) + common.ApiSuccess(c, pageInfo) +} + +// SearchVendors 搜索供应商 +func SearchVendors(c *gin.Context) { + keyword := c.Query("keyword") + pageInfo := common.GetPageQuery(c) + vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) + if err != nil { + common.ApiError(c, err) + return + } + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(vendors) + common.ApiSuccess(c, pageInfo) +} + +// GetVendorMeta 根据 ID 获取供应商 +func GetVendorMeta(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + v, err := model.GetVendorByID(id) + if err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, v) +} + +// CreateVendorMeta 新建供应商 +func CreateVendorMeta(c *gin.Context) { + var v model.Vendor + if err := c.ShouldBindJSON(&v); err != nil { + common.ApiError(c, err) + return + } + if v.Name == "" { + common.ApiErrorMsg(c, "供应商名称不能为空") + return + } + // 创建前先检查名称 + if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "供应商名称已存在") + return + } + + if err := v.Insert(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &v) +} + +// UpdateVendorMeta 更新供应商 +func UpdateVendorMeta(c *gin.Context) { + var v model.Vendor + if err := c.ShouldBindJSON(&v); err != nil { + common.ApiError(c, err) + return + } + if v.Id == 0 { + common.ApiErrorMsg(c, "缺少供应商 ID") + return + } + // 名称冲突检查 + if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil { + common.ApiError(c, err) + return + } else if dup { + common.ApiErrorMsg(c, "供应商名称已存在") + return + } + + if err := v.Update(); err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, &v) +} + +// DeleteVendorMeta 删除供应商 +func DeleteVendorMeta(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiError(c, err) + return + } + if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil { + common.ApiError(c, err) + return + } + common.ApiSuccess(c, nil) +} \ No newline at end of file diff --git a/model/main.go b/model/main.go index 38dd2aee..b93f01a2 100644 --- a/model/main.go +++ b/model/main.go @@ -250,6 +250,9 @@ func migrateDB() error { &TopUp{}, &QuotaData{}, &Task{}, + &Model{}, + &Vendor{}, + &PrefillGroup{}, &Setup{}, &TwoFA{}, &TwoFABackupCode{}, @@ -278,6 +281,9 @@ func migrateDBFast() error { {&TopUp{}, "TopUp"}, {&QuotaData{}, "QuotaData"}, {&Task{}, "Task"}, + {&Model{}, "Model"}, + {&Vendor{}, "Vendor"}, + {&PrefillGroup{}, "PrefillGroup"}, {&Setup{}, "Setup"}, {&TwoFA{}, "TwoFA"}, {&TwoFABackupCode{}, "TwoFABackupCode"}, diff --git a/model/missing_models.go b/model/missing_models.go new file mode 100644 index 00000000..57269f5f --- /dev/null +++ b/model/missing_models.go @@ -0,0 +1,30 @@ +package model + +// GetMissingModels returns model names that are referenced in the system +func GetMissingModels() ([]string, error) { + // 1. 获取所有已启用模型(去重) + models := GetEnabledModels() + if len(models) == 0 { + return []string{}, nil + } + + // 2. 查询已有的元数据模型名 + var existing []string + if err := DB.Model(&Model{}).Where("model_name IN ?", models).Pluck("model_name", &existing).Error; err != nil { + return nil, err + } + + existingSet := make(map[string]struct{}, len(existing)) + for _, e := range existing { + existingSet[e] = struct{}{} + } + + // 3. 收集缺失模型 + var missing []string + for _, name := range models { + if _, ok := existingSet[name]; !ok { + missing = append(missing, name) + } + } + return missing, nil +} diff --git a/model/model_extra.go b/model/model_extra.go new file mode 100644 index 00000000..6ade6ff0 --- /dev/null +++ b/model/model_extra.go @@ -0,0 +1,34 @@ +package model + +// GetModelEnableGroups 返回指定模型名称可用的用户分组列表。 +// 使用在 updatePricing() 中维护的缓存映射,O(1) 读取,适合高并发场景。 +func GetModelEnableGroups(modelName string) []string { + // 确保缓存最新 + GetPricing() + + if modelName == "" { + return make([]string, 0) + } + + modelEnableGroupsLock.RLock() + groups, ok := modelEnableGroups[modelName] + modelEnableGroupsLock.RUnlock() + if !ok { + return make([]string, 0) + } + return groups +} + +// GetModelQuotaType 返回指定模型的计费类型(quota_type)。 +// 同样使用缓存映射,避免每次遍历定价切片。 +func GetModelQuotaType(modelName string) int { + GetPricing() + + modelEnableGroupsLock.RLock() + quota, ok := modelQuotaTypeMap[modelName] + modelEnableGroupsLock.RUnlock() + if !ok { + return 0 + } + return quota +} diff --git a/model/model_meta.go b/model/model_meta.go new file mode 100644 index 00000000..5ccd80c5 --- /dev/null +++ b/model/model_meta.go @@ -0,0 +1,204 @@ +package model + +import ( + "one-api/common" + "strconv" + "strings" + + "gorm.io/gorm" +) + +// Model 用于存储模型的元数据,例如描述、标签等 +// ModelName 字段具有唯一性约束,确保每个模型只会出现一次 +// Tags 字段使用逗号分隔的字符串保存标签集合,后期可根据需要扩展为 JSON 类型 +// Status: 1 表示启用,0 表示禁用,保留以便后续功能扩展 +// CreatedTime 和 UpdatedTime 使用 Unix 时间戳(秒)保存方便跨数据库移植 +// DeletedAt 采用 GORM 的软删除特性,便于后续数据恢复 +// +// 该表设计遵循第三范式(3NF): +// 1. 每一列都与主键(Id 或 ModelName)直接相关 +// 2. 不存在部分依赖(ModelName 是唯一键) +// 3. 不存在传递依赖(描述、标签等都依赖于 ModelName,而非依赖于其他非主键列) +// 这样既保证了数据一致性,也方便后期扩展 + +// 模型名称匹配规则 +const ( + NameRuleExact = iota // 0 精确匹配 + NameRulePrefix // 1 前缀匹配 + NameRuleContains // 2 包含匹配 + NameRuleSuffix // 3 后缀匹配 +) + +type BoundChannel struct { + Name string `json:"name"` + Type int `json:"type"` +} + +type Model struct { + Id int `json:"id"` + ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,where:deleted_at IS NULL"` + Description string `json:"description,omitempty" gorm:"type:text"` + Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"` + VendorID int `json:"vendor_id,omitempty" gorm:"index"` + Endpoints string `json:"endpoints,omitempty" gorm:"type:text"` + Status int `json:"status" gorm:"default:1"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` + + BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"` + EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"` + QuotaType int `json:"quota_type" gorm:"-"` + NameRule int `json:"name_rule" gorm:"default:0"` +} + +// Insert 创建新的模型元数据记录 +func (mi *Model) Insert() error { + now := common.GetTimestamp() + mi.CreatedTime = now + mi.UpdatedTime = now + return DB.Create(mi).Error +} + +// IsModelNameDuplicated 检查模型名称是否重复(排除自身 ID) +func IsModelNameDuplicated(id int, name string) (bool, error) { + if name == "" { + return false, nil + } + var cnt int64 + err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error + return cnt > 0, err +} + +// Update 更新现有模型记录 +func (mi *Model) Update() error { + // 仅更新需要变更的字段,避免覆盖 CreatedTime + mi.UpdatedTime = common.GetTimestamp() + + // 排除 created_time,其余字段自动更新,避免新增字段时需要维护列表 + return DB.Model(&Model{}).Where("id = ?", mi.Id).Omit("created_time").Updates(mi).Error +} + +// Delete 软删除模型记录 +func (mi *Model) Delete() error { + return DB.Delete(mi).Error +} + +// GetModelByName 根据模型名称查询元数据 +func GetModelByName(name string) (*Model, error) { + var mi Model + err := DB.Where("model_name = ?", name).First(&mi).Error + if err != nil { + return nil, err + } + return &mi, nil +} + +// GetVendorModelCounts 统计每个供应商下模型数量(不受分页影响) +func GetVendorModelCounts() (map[int64]int64, error) { + var stats []struct { + VendorID int64 + Count int64 + } + if err := DB.Model(&Model{}). + Select("vendor_id as vendor_id, count(*) as count"). + Group("vendor_id"). + Scan(&stats).Error; err != nil { + return nil, err + } + m := make(map[int64]int64, len(stats)) + for _, s := range stats { + m[s.VendorID] = s.Count + } + return m, nil +} + +// GetAllModels 分页获取所有模型元数据 +func GetAllModels(offset int, limit int) ([]*Model, error) { + var models []*Model + err := DB.Offset(offset).Limit(limit).Find(&models).Error + return models, err +} + +// GetBoundChannels 查询支持该模型的渠道(名称+类型) +func GetBoundChannels(modelName string) ([]BoundChannel, error) { + var channels []BoundChannel + err := DB.Table("channels"). + Select("channels.name, channels.type"). + Joins("join abilities on abilities.channel_id = channels.id"). + Where("abilities.model = ? AND abilities.enabled = ?", modelName, true). + Group("channels.id"). + Scan(&channels).Error + return channels, err +} + +// FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含 +func FindModelByNameWithRule(name string) (*Model, error) { + // 1. 精确匹配 + if m, err := GetModelByName(name); err == nil { + return m, nil + } + // 2. 规则匹配 + var models []*Model + if err := DB.Where("name_rule <> ?", NameRuleExact).Find(&models).Error; err != nil { + return nil, err + } + var prefixMatch, suffixMatch, containsMatch *Model + for _, m := range models { + switch m.NameRule { + case NameRulePrefix: + if strings.HasPrefix(name, m.ModelName) { + if prefixMatch == nil || len(m.ModelName) > len(prefixMatch.ModelName) { + prefixMatch = m + } + } + case NameRuleSuffix: + if strings.HasSuffix(name, m.ModelName) { + if suffixMatch == nil || len(m.ModelName) > len(suffixMatch.ModelName) { + suffixMatch = m + } + } + case NameRuleContains: + if strings.Contains(name, m.ModelName) { + if containsMatch == nil || len(m.ModelName) > len(containsMatch.ModelName) { + containsMatch = m + } + } + } + } + if prefixMatch != nil { + return prefixMatch, nil + } + if suffixMatch != nil { + return suffixMatch, nil + } + if containsMatch != nil { + return containsMatch, nil + } + return nil, gorm.ErrRecordNotFound +} + +// SearchModels 根据关键词和供应商搜索模型,支持分页 +func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) { + var models []*Model + db := DB.Model(&Model{}) + if keyword != "" { + like := "%" + keyword + "%" + db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like) + } + if vendor != "" { + // 如果是数字,按供应商 ID 精确匹配;否则按名称模糊匹配 + if vid, err := strconv.Atoi(vendor); err == nil { + db = db.Where("models.vendor_id = ?", vid) + } else { + db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%") + } + } + var total int64 + err := db.Count(&total).Error + if err != nil { + return nil, 0, err + } + err = db.Offset(offset).Limit(limit).Order("models.id DESC").Find(&models).Error + return models, total, err +} diff --git a/model/prefill_group.go b/model/prefill_group.go new file mode 100644 index 00000000..d9e92faa --- /dev/null +++ b/model/prefill_group.go @@ -0,0 +1,126 @@ +package model + +import ( + "encoding/json" + "database/sql/driver" + "one-api/common" + + "gorm.io/gorm" +) + +// PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。 +// Name 字段保持唯一,用于在前端下拉框中展示。 +// Type 字段用于区分组的类别,可选值如:model、tag、endpoint。 +// Items 字段使用 JSON 数组保存对应类型的字符串集合,示例: +// ["gpt-4o", "gpt-3.5-turbo"] +// 设计遵循 3NF,避免冗余,提供灵活扩展能力。 + +// JSONValue 基于 json.RawMessage 实现,支持从数据库的 []byte 和 string 两种类型读取 +type JSONValue json.RawMessage + +// Value 实现 driver.Valuer 接口,用于数据库写入 +func (j JSONValue) Value() (driver.Value, error) { + if j == nil { + return nil, nil + } + return []byte(j), nil +} + +// Scan 实现 sql.Scanner 接口,兼容不同驱动返回的类型 +func (j *JSONValue) Scan(value interface{}) error { + switch v := value.(type) { + case nil: + *j = nil + return nil + case []byte: + // 拷贝底层字节,避免保留底层缓冲区 + b := make([]byte, len(v)) + copy(b, v) + *j = JSONValue(b) + return nil + case string: + *j = JSONValue([]byte(v)) + return nil + default: + // 其他类型尝试序列化为 JSON + b, err := json.Marshal(v) + if err != nil { + return err + } + *j = JSONValue(b) + return nil + } +} + +// MarshalJSON 确保在对外编码时与 json.RawMessage 行为一致 +func (j JSONValue) MarshalJSON() ([]byte, error) { + if j == nil { + return []byte("null"), nil + } + return j, nil +} + +// UnmarshalJSON 确保在对外解码时与 json.RawMessage 行为一致 +func (j *JSONValue) UnmarshalJSON(data []byte) error { + if data == nil { + *j = nil + return nil + } + b := make([]byte, len(data)) + copy(b, data) + *j = JSONValue(b) + return nil +} + +type PrefillGroup struct { + Id int `json:"id"` + Name string `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"` + Type string `json:"type" gorm:"size:32;index;not null"` + Items JSONValue `json:"items" gorm:"type:json"` + Description string `json:"description,omitempty" gorm:"type:varchar(255)"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` +} + +// Insert 新建组 +func (g *PrefillGroup) Insert() error { + now := common.GetTimestamp() + g.CreatedTime = now + g.UpdatedTime = now + return DB.Create(g).Error +} + +// IsPrefillGroupNameDuplicated 检查组名称是否重复(排除自身 ID) +func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) { + if name == "" { + return false, nil + } + var cnt int64 + err := DB.Model(&PrefillGroup{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error + return cnt > 0, err +} + +// Update 更新组 +func (g *PrefillGroup) Update() error { + g.UpdatedTime = common.GetTimestamp() + return DB.Save(g).Error +} + +// DeleteByID 根据 ID 删除组 +func DeletePrefillGroupByID(id int) error { + return DB.Delete(&PrefillGroup{}, id).Error +} + +// GetAllPrefillGroups 获取全部组,可按类型过滤(为空则返回全部) +func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) { + var groups []*PrefillGroup + query := DB.Model(&PrefillGroup{}) + if groupType != "" { + query = query.Where("type = ?", groupType) + } + if err := query.Order("updated_time DESC").Find(&groups).Error; err != nil { + return nil, err + } + return groups, nil +} diff --git a/model/pricing.go b/model/pricing.go index a280b524..2b3920ba 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -1,30 +1,50 @@ package model import ( - "fmt" - "one-api/common" - "one-api/constant" - "one-api/setting/ratio_setting" - "one-api/types" - "sync" - "time" + "encoding/json" + "fmt" + "strings" + + "one-api/common" + "one-api/constant" + "one-api/setting/ratio_setting" + "one-api/types" + "sync" + "time" ) type Pricing struct { - ModelName string `json:"model_name"` - QuotaType int `json:"quota_type"` - ModelRatio float64 `json:"model_ratio"` - ModelPrice float64 `json:"model_price"` - OwnerBy string `json:"owner_by"` - CompletionRatio float64 `json:"completion_ratio"` - EnableGroup []string `json:"enable_groups"` - SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` + ModelName string `json:"model_name"` + Description string `json:"description,omitempty"` + Tags string `json:"tags,omitempty"` + VendorID int `json:"vendor_id,omitempty"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + OwnerBy string `json:"owner_by"` + CompletionRatio float64 `json:"completion_ratio"` + EnableGroup []string `json:"enable_groups"` + SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` +} + +type PricingVendor struct { + ID int `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Icon string `json:"icon,omitempty"` } var ( - pricingMap []Pricing - lastGetPricingTime time.Time - updatePricingLock sync.Mutex + pricingMap []Pricing + vendorsList []PricingVendor + supportedEndpointMap map[string]common.EndpointInfo + lastGetPricingTime time.Time + updatePricingLock sync.Mutex + + // 缓存映射:模型名 -> 启用分组 / 计费类型 + modelEnableGroups = make(map[string][]string) + modelQuotaTypeMap = make(map[string]int) + modelEnableGroupsLock = sync.RWMutex{} ) var ( @@ -46,6 +66,15 @@ func GetPricing() []Pricing { return pricingMap } +// GetVendors 返回当前定价接口使用到的供应商信息 +func GetVendors() []PricingVendor { + if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { + // 保证先刷新一次 + GetPricing() + } + return vendorsList +} + func GetModelSupportEndpointTypes(model string) []constant.EndpointType { if model == "" { return make([]constant.EndpointType, 0) @@ -65,6 +94,77 @@ func updatePricing() { common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) return } + // 预加载模型元数据与供应商一次,避免循环查询 + var allMeta []Model + _ = DB.Find(&allMeta).Error + metaMap := make(map[string]*Model) + prefixList := make([]*Model, 0) + suffixList := make([]*Model, 0) + containsList := make([]*Model, 0) + for i := range allMeta { + m := &allMeta[i] + if m.NameRule == NameRuleExact { + metaMap[m.ModelName] = m + } else { + switch m.NameRule { + case NameRulePrefix: + prefixList = append(prefixList, m) + case NameRuleSuffix: + suffixList = append(suffixList, m) + case NameRuleContains: + containsList = append(containsList, m) + } + } + } + + // 将非精确规则模型匹配到 metaMap + for _, m := range prefixList { + for _, pricingModel := range enableAbilities { + if strings.HasPrefix(pricingModel.Model, m.ModelName) { + if _, exists := metaMap[pricingModel.Model]; !exists { + metaMap[pricingModel.Model] = m + } + } + } + } + for _, m := range suffixList { + for _, pricingModel := range enableAbilities { + if strings.HasSuffix(pricingModel.Model, m.ModelName) { + if _, exists := metaMap[pricingModel.Model]; !exists { + metaMap[pricingModel.Model] = m + } + } + } + } + for _, m := range containsList { + for _, pricingModel := range enableAbilities { + if strings.Contains(pricingModel.Model, m.ModelName) { + if _, exists := metaMap[pricingModel.Model]; !exists { + metaMap[pricingModel.Model] = m + } + } + } + } + + // 预加载供应商 + var vendors []Vendor + _ = DB.Find(&vendors).Error + vendorMap := make(map[int]*Vendor) + for i := range vendors { + vendorMap[vendors[i].Id] = &vendors[i] + } + + // 构建对前端友好的供应商列表 + vendorsList = make([]PricingVendor, 0, len(vendors)) + for _, v := range vendors { + vendorsList = append(vendorsList, PricingVendor{ + ID: v.Id, + Name: v.Name, + Description: v.Description, + Icon: v.Icon, + }) + } + modelGroupsMap := make(map[string]*types.Set[string]) for _, ability := range enableAbilities { @@ -79,20 +179,34 @@ func updatePricing() { //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点 modelSupportEndpointsStr := make(map[string][]string) - for _, ability := range enableAbilities { - endpoints, ok := modelSupportEndpointsStr[ability.Model] - if !ok { - endpoints = make([]string, 0) - modelSupportEndpointsStr[ability.Model] = endpoints - } - channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) - for _, channelType := range channelTypes { - if !common.StringsContains(endpoints, string(channelType)) { - endpoints = append(endpoints, string(channelType)) - } - } - modelSupportEndpointsStr[ability.Model] = endpoints - } + // 先根据已有能力填充原生端点 + for _, ability := range enableAbilities { + endpoints := modelSupportEndpointsStr[ability.Model] + channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) + for _, channelType := range channelTypes { + if !common.StringsContains(endpoints, string(channelType)) { + endpoints = append(endpoints, string(channelType)) + } + } + modelSupportEndpointsStr[ability.Model] = endpoints + } + + // 再补充模型自定义端点 + for modelName, meta := range metaMap { + if strings.TrimSpace(meta.Endpoints) == "" { + continue + } + var raw map[string]interface{} + if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { + endpoints := modelSupportEndpointsStr[modelName] + for k := range raw { + if !common.StringsContains(endpoints, k) { + endpoints = append(endpoints, k) + } + } + modelSupportEndpointsStr[modelName] = endpoints + } + } modelSupportEndpointTypes = make(map[string][]constant.EndpointType) for model, endpoints := range modelSupportEndpointsStr { @@ -102,26 +216,92 @@ func updatePricing() { supportedEndpoints = append(supportedEndpoints, endpointType) } modelSupportEndpointTypes[model] = supportedEndpoints - } + } - pricingMap = make([]Pricing, 0) - for model, groups := range modelGroupsMap { - pricing := Pricing{ - ModelName: model, - EnableGroup: groups.Items(), - SupportedEndpointTypes: modelSupportEndpointTypes[model], - } - modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) - if findPrice { - pricing.ModelPrice = modelPrice - pricing.QuotaType = 1 - } else { - modelRatio, _, _ := ratio_setting.GetModelRatio(model) - pricing.ModelRatio = modelRatio - pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model) - pricing.QuotaType = 0 - } - pricingMap = append(pricingMap, pricing) - } - lastGetPricingTime = time.Now() + // 构建全局 supportedEndpointMap(默认 + 自定义覆盖) + supportedEndpointMap = make(map[string]common.EndpointInfo) + // 1. 默认端点 + for _, endpoints := range modelSupportEndpointTypes { + for _, et := range endpoints { + if info, ok := common.GetDefaultEndpointInfo(et); ok { + if _, exists := supportedEndpointMap[string(et)]; !exists { + supportedEndpointMap[string(et)] = info + } + } + } + } + // 2. 自定义端点(models 表)覆盖默认 + for _, meta := range metaMap { + if strings.TrimSpace(meta.Endpoints) == "" { + continue + } + var raw map[string]interface{} + if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { + for k, v := range raw { + switch val := v.(type) { + case string: + supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"} + case map[string]interface{}: + ep := common.EndpointInfo{Method: "POST"} + if p, ok := val["path"].(string); ok { + ep.Path = p + } + if m, ok := val["method"].(string); ok { + ep.Method = strings.ToUpper(m) + } + supportedEndpointMap[k] = ep + default: + // ignore unsupported types + } + } + } + } + + pricingMap = make([]Pricing, 0) + for model, groups := range modelGroupsMap { + pricing := Pricing{ + ModelName: model, + EnableGroup: groups.Items(), + SupportedEndpointTypes: modelSupportEndpointTypes[model], + } + + // 补充模型元数据(描述、标签、供应商、状态) + if meta, ok := metaMap[model]; ok { + // 若模型被禁用(status!=1),则直接跳过,不返回给前端 + if meta.Status != 1 { + continue + } + pricing.Description = meta.Description + pricing.Tags = meta.Tags + pricing.VendorID = meta.VendorID + } + modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) + if findPrice { + pricing.ModelPrice = modelPrice + pricing.QuotaType = 1 + } else { + modelRatio, _, _ := ratio_setting.GetModelRatio(model) + pricing.ModelRatio = modelRatio + pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model) + pricing.QuotaType = 0 + } + pricingMap = append(pricingMap, pricing) + } + + // 刷新缓存映射,供高并发快速查询 + modelEnableGroupsLock.Lock() + modelEnableGroups = make(map[string][]string) + modelQuotaTypeMap = make(map[string]int) + for _, p := range pricingMap { + modelEnableGroups[p.ModelName] = p.EnableGroup + modelQuotaTypeMap[p.ModelName] = p.QuotaType + } + modelEnableGroupsLock.Unlock() + + lastGetPricingTime = time.Now() +} + +// GetSupportedEndpointMap 返回全局端点到路径的映射 +func GetSupportedEndpointMap() map[string]common.EndpointInfo { + return supportedEndpointMap } diff --git a/model/pricing_refresh.go b/model/pricing_refresh.go new file mode 100644 index 00000000..de72a8bb --- /dev/null +++ b/model/pricing_refresh.go @@ -0,0 +1,14 @@ +package model + +// RefreshPricing 强制立即重新计算与定价相关的缓存。 +// 该方法用于需要最新数据的内部管理 API, +// 因此会绕过默认的 1 分钟延迟刷新。 +func RefreshPricing() { + updatePricingLock.Lock() + defer updatePricingLock.Unlock() + + modelSupportEndpointsLock.Lock() + defer modelSupportEndpointsLock.Unlock() + + updatePricing() +} diff --git a/model/vendor_meta.go b/model/vendor_meta.go new file mode 100644 index 00000000..fd316156 --- /dev/null +++ b/model/vendor_meta.go @@ -0,0 +1,88 @@ +package model + +import ( + "one-api/common" + + "gorm.io/gorm" +) + +// Vendor 用于存储供应商信息,供模型引用 +// Name 唯一,用于在模型中关联 +// Icon 采用 @lobehub/icons 的图标名,前端可直接渲染 +// Status 预留字段,1 表示启用 +// 本表同样遵循 3NF 设计范式 + +type Vendor struct { + Id int `json:"id"` + Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,where:deleted_at IS NULL"` + Description string `json:"description,omitempty" gorm:"type:text"` + Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` + Status int `json:"status" gorm:"default:1"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` +} + +// Insert 创建新的供应商记录 +func (v *Vendor) Insert() error { + now := common.GetTimestamp() + v.CreatedTime = now + v.UpdatedTime = now + return DB.Create(v).Error +} + +// IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID) +func IsVendorNameDuplicated(id int, name string) (bool, error) { + if name == "" { + return false, nil + } + var cnt int64 + err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error + return cnt > 0, err +} + +// Update 更新供应商记录 +func (v *Vendor) Update() error { + v.UpdatedTime = common.GetTimestamp() + return DB.Save(v).Error +} + +// Delete 软删除供应商 +func (v *Vendor) Delete() error { + return DB.Delete(v).Error +} + +// GetVendorByID 根据 ID 获取供应商 +func GetVendorByID(id int) (*Vendor, error) { + var v Vendor + err := DB.First(&v, id).Error + if err != nil { + return nil, err + } + return &v, nil +} + +// GetAllVendors 获取全部供应商(分页) +func GetAllVendors(offset int, limit int) ([]*Vendor, error) { + var vendors []*Vendor + err := DB.Offset(offset).Limit(limit).Find(&vendors).Error + return vendors, err +} + +// SearchVendors 按关键字搜索供应商 +func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) { + db := DB.Model(&Vendor{}) + if keyword != "" { + like := "%" + keyword + "%" + db = db.Where("name LIKE ? OR description LIKE ?", like, like) + } + var total int64 + if err := db.Count(&total).Error; err != nil { + return nil, 0, err + } + var vendors []*Vendor + if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil { + return nil, 0, err + } + return vendors, total, nil +} \ No newline at end of file diff --git a/router/api-router.go b/router/api-router.go index ab7f6880..e8519e23 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -179,6 +179,16 @@ func SetApiRouter(router *gin.Engine) { { groupRoute.GET("/", controller.GetGroups) } + + prefillGroupRoute := apiRouter.Group("/prefill_group") + prefillGroupRoute.Use(middleware.AdminAuth()) + { + prefillGroupRoute.GET("/", controller.GetPrefillGroups) + prefillGroupRoute.POST("/", controller.CreatePrefillGroup) + prefillGroupRoute.PUT("/", controller.UpdatePrefillGroup) + prefillGroupRoute.DELETE("/:id", controller.DeletePrefillGroup) + } + mjRoute := apiRouter.Group("/mj") mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney) mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney) @@ -188,5 +198,28 @@ func SetApiRouter(router *gin.Engine) { taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask) taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask) } + + vendorRoute := apiRouter.Group("/vendors") + vendorRoute.Use(middleware.AdminAuth()) + { + vendorRoute.GET("/", controller.GetAllVendors) + vendorRoute.GET("/search", controller.SearchVendors) + vendorRoute.GET("/:id", controller.GetVendorMeta) + vendorRoute.POST("/", controller.CreateVendorMeta) + vendorRoute.PUT("/", controller.UpdateVendorMeta) + vendorRoute.DELETE("/:id", controller.DeleteVendorMeta) + } + + modelsRoute := apiRouter.Group("/models") + modelsRoute.Use(middleware.AdminAuth()) + { + modelsRoute.GET("/missing", controller.GetMissingModels) + modelsRoute.GET("/", controller.GetAllModelsMeta) + modelsRoute.GET("/search", controller.SearchModelsMeta) + modelsRoute.GET("/:id", controller.GetModelMeta) + modelsRoute.POST("/", controller.CreateModelMeta) + modelsRoute.PUT("/", controller.UpdateModelMeta) + modelsRoute.DELETE("/:id", controller.DeleteModelMeta) + } } } diff --git a/web/src/App.js b/web/src/App.js index 47304b16..bf8397ba 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -39,6 +39,7 @@ import Chat2Link from './pages/Chat2Link'; import Midjourney from './pages/Midjourney'; import Pricing from './pages/Pricing/index.js'; import Task from './pages/Task/index.js'; +import ModelPage from './pages/Model/index.js'; import Playground from './pages/Playground/index.js'; import OAuth2Callback from './components/auth/OAuth2Callback.js'; import PersonalSetting from './components/settings/PersonalSetting.js'; @@ -71,6 +72,14 @@ function App() { } /> + + + + } + /> { - const { t } = useTranslation(); - - // 初始化JSON数据 - const [jsonData, setJsonData] = useState(() => { - // 初始化时解析JSON数据 - if (value && value.trim()) { - try { - const parsed = JSON.parse(value); - return parsed; - } catch (error) { - return {}; - } - } - return {}; - }); - - // 根据键数量决定默认编辑模式 - const [editMode, setEditMode] = useState(() => { - // 如果初始JSON数据的键数量大于10个,则默认使用手动模式 - if (value && value.trim()) { - try { - const parsed = JSON.parse(value); - const keyCount = Object.keys(parsed).length; - return keyCount > 10 ? 'manual' : 'visual'; - } catch (error) { - // JSON无效时默认显示手动编辑模式 - return 'manual'; - } - } - return 'visual'; - }); - const [jsonError, setJsonError] = useState(''); - - // 数据同步 - 当value变化时总是更新jsonData(如果JSON有效) - useEffect(() => { - try { - const parsed = value && value.trim() ? JSON.parse(value) : {}; - setJsonData(parsed); - setJsonError(''); - } catch (error) { - console.log('JSON解析失败:', error.message); - setJsonError(error.message); - // JSON格式错误时不更新jsonData - } - }, [value]); - - - // 处理可视化编辑的数据变化 - const handleVisualChange = useCallback((newData) => { - setJsonData(newData); - setJsonError(''); - const jsonString = Object.keys(newData).length === 0 ? '' : JSON.stringify(newData, null, 2); - - // 通过formApi设置值(如果提供的话) - if (formApi && field) { - formApi.setValue(field, jsonString); - } - - onChange?.(jsonString); - }, [onChange, formApi, field]); - - // 处理手动编辑的数据变化 - const handleManualChange = useCallback((newValue) => { - onChange?.(newValue); - // 验证JSON格式 - if (newValue && newValue.trim()) { - try { - const parsed = JSON.parse(newValue); - setJsonError(''); - // 预先准备可视化数据,但不立即应用 - // 这样切换到可视化模式时数据已经准备好了 - } catch (error) { - setJsonError(error.message); - } - } else { - setJsonError(''); - } - }, [onChange]); - - // 切换编辑模式 - const toggleEditMode = useCallback(() => { - if (editMode === 'visual') { - // 从可视化模式切换到手动模式 - setEditMode('manual'); - } else { - // 从手动模式切换到可视化模式,需要验证JSON - try { - const parsed = value && value.trim() ? JSON.parse(value) : {}; - setJsonData(parsed); - setJsonError(''); - setEditMode('visual'); - } catch (error) { - setJsonError(error.message); - // JSON格式错误时不切换模式 - return; - } - } - }, [editMode, value]); - - // 添加键值对 - const addKeyValue = useCallback(() => { - const newData = { ...jsonData }; - const keys = Object.keys(newData); - let newKey = 'key'; - let counter = 1; - while (newData.hasOwnProperty(newKey)) { - newKey = `key${counter}`; - counter++; - } - newData[newKey] = ''; - handleVisualChange(newData); - }, [jsonData, handleVisualChange]); - - // 删除键值对 - const removeKeyValue = useCallback((keyToRemove) => { - const newData = { ...jsonData }; - delete newData[keyToRemove]; - handleVisualChange(newData); - }, [jsonData, handleVisualChange]); - - // 更新键名 - const updateKey = useCallback((oldKey, newKey) => { - if (oldKey === newKey) return; - const newData = { ...jsonData }; - const value = newData[oldKey]; - delete newData[oldKey]; - newData[newKey] = value; - handleVisualChange(newData); - }, [jsonData, handleVisualChange]); - - // 更新值 - const updateValue = useCallback((key, newValue) => { - const newData = { ...jsonData }; - newData[key] = newValue; - handleVisualChange(newData); - }, [jsonData, handleVisualChange]); - - // 填入模板 - const fillTemplate = useCallback(() => { - if (template) { - const templateString = JSON.stringify(template, null, 2); - - // 通过formApi设置值(如果提供的话) - if (formApi && field) { - formApi.setValue(field, templateString); - } - - // 无论哪种模式都要更新值 - onChange?.(templateString); - - // 如果是可视化模式,同时更新jsonData - if (editMode === 'visual') { - setJsonData(template); - } - - // 清除错误状态 - setJsonError(''); - } - }, [template, onChange, editMode, formApi, field]); - - // 渲染键值对编辑器 - const renderKeyValueEditor = () => { - if (typeof jsonData !== 'object' || jsonData === null) { - return ( -
-
- -
- - {t('无效的JSON数据,请检查格式')} - -
- ); - } - const entries = Object.entries(jsonData); - - return ( -
- {entries.length === 0 && ( -
-
- -
- - {t('暂无数据,点击下方按钮添加键值对')} - -
- )} - - {entries.map(([key, value], index) => ( - - - -
- {t('键名')} - updateKey(key, newKey)} - size="small" - /> -
- - -
- {t('值')} - updateValue(key, newValue)} - size="small" - /> -
- - -
-
- -
-
- ))} - -
- -
-
- ); - }; - - // 渲染对象编辑器(用于复杂JSON) - const renderObjectEditor = () => { - const entries = Object.entries(jsonData); - - return ( -
- {entries.length === 0 && ( -
-
- -
- - {t('暂无参数,点击下方按钮添加请求参数')} - -
- )} - - {entries.map(([key, value], index) => ( - - - -
- {t('参数名')} - updateKey(key, newKey)} - size="small" - /> -
- - -
- {t('参数值')} ({typeof value}) - {renderValueInput(key, value)} -
- - -
-
- -
-
- ))} - -
- -
-
- ); - }; - - // 渲染参数值输入控件 - const renderValueInput = (key, value) => { - const valueType = typeof value; - - if (valueType === 'boolean') { - return ( -
- updateValue(key, newValue)} - size="small" - /> - - {value ? t('true') : t('false')} - -
- ); - } - - if (valueType === 'number') { - return ( - updateValue(key, newValue)} - size="small" - style={{ width: '100%' }} - step={key === 'temperature' ? 0.1 : 1} - precision={key === 'temperature' ? 2 : 0} - placeholder={t('输入数字')} - /> - ); - } - - // 字符串类型或其他类型 - return ( - { - // 尝试转换为适当的类型 - let convertedValue = newValue; - if (newValue === 'true') convertedValue = true; - else if (newValue === 'false') convertedValue = false; - else if (!isNaN(newValue) && newValue !== '' && newValue !== '0') { - convertedValue = Number(newValue); - } - - updateValue(key, convertedValue); - }} - size="small" - /> - ); - }; - - // 渲染区域编辑器(特殊格式) - const renderRegionEditor = () => { - const entries = Object.entries(jsonData); - const defaultEntry = entries.find(([key]) => key === 'default'); - const modelEntries = entries.filter(([key]) => key !== 'default'); - - return ( -
- {/* 默认区域 */} - -
- {t('默认区域')} -
- updateValue('default', value)} - size="small" - /> -
- - {/* 模型专用区域 */} -
- {t('模型专用区域')} - {modelEntries.map(([modelName, region], index) => ( - - - -
- {t('模型名称')} - updateKey(modelName, newKey)} - size="small" - /> -
- - -
- {t('区域')} - updateValue(modelName, newValue)} - size="small" - /> -
- - -
-
- -
-
- ))} - -
- -
-
-
- ); - }; - - // 渲染可视化编辑器 - const renderVisualEditor = () => { - switch (editorType) { - case 'region': - return renderRegionEditor(); - case 'object': - return renderObjectEditor(); - case 'keyValue': - default: - return renderKeyValueEditor(); - } - }; - - const hasJsonError = jsonError && jsonError.trim() !== ''; - - return ( -
- {/* Label统一显示在上方 */} - {label && ( -
- {label} -
- )} - - {/* 编辑模式切换 */} -
-
- {editMode === 'visual' && ( - - {t('可视化模式')} - - )} - {editMode === 'manual' && ( - - {t('手动编辑模式')} - - )} -
-
- {template && templateLabel && ( - - )} - - - - -
-
- - {/* JSON错误提示 */} - {hasJsonError && ( - - )} - - {/* 编辑器内容 */} - {editMode === 'visual' ? ( -
- - {renderVisualEditor()} - - {/* 可视化模式下的额外文本显示在下方 */} - {extraText && ( -
- {extraText} -
- )} - {/* 隐藏的Form字段用于验证和数据绑定 */} - -
- ) : ( - - )} - - {/* 额外文本在手动编辑模式下显示 */} - {extraText && editMode === 'manual' && ( -
- {extraText} -
- )} -
- ); -}; - -export default JSONEditor; \ No newline at end of file diff --git a/web/src/components/common/ui/CardPro.js b/web/src/components/common/ui/CardPro.js index 5745b9b3..ad6dda85 100644 --- a/web/src/components/common/ui/CardPro.js +++ b/web/src/components/common/ui/CardPro.js @@ -112,6 +112,7 @@ const CardPro = ({ icon={showMobileActions ? : } type="tertiary" size="small" + theme='outline' block > {showMobileActions ? t('隐藏操作项') : t('显示操作项')} diff --git a/web/src/components/common/ui/CardTable.js b/web/src/components/common/ui/CardTable.js index 75b6df00..f91ff200 100644 --- a/web/src/components/common/ui/CardTable.js +++ b/web/src/components/common/ui/CardTable.js @@ -23,6 +23,7 @@ import { Table, Card, Skeleton, Pagination, Empty, Button, Collapsible } from '@ import { IconChevronDown, IconChevronUp } from '@douyinfe/semi-icons'; import PropTypes from 'prop-types'; import { useIsMobile } from '../../../hooks/common/useIsMobile'; +import { useMinimumLoadingTime } from '../../../hooks/common/useMinimumLoadingTime'; /** * CardTable 响应式表格组件 @@ -40,25 +41,8 @@ const CardTable = ({ }) => { const isMobile = useIsMobile(); const { t } = useTranslation(); - - const [showSkeleton, setShowSkeleton] = useState(loading); - const loadingStartRef = useRef(Date.now()); - - useEffect(() => { - if (loading) { - loadingStartRef.current = Date.now(); - setShowSkeleton(true); - } else { - const elapsed = Date.now() - loadingStartRef.current; - const remaining = Math.max(0, 500 - elapsed); - if (remaining === 0) { - setShowSkeleton(false); - } else { - const timer = setTimeout(() => setShowSkeleton(false), remaining); - return () => clearTimeout(timer); - } - } - }, [loading]); + + const showSkeleton = useMinimumLoadingTime(loading); const getRowKey = (record, index) => { if (typeof rowKey === 'function') return rowKey(record); diff --git a/web/src/components/common/ui/JSONEditor.js b/web/src/components/common/ui/JSONEditor.js new file mode 100644 index 00000000..f4f5eee9 --- /dev/null +++ b/web/src/components/common/ui/JSONEditor.js @@ -0,0 +1,669 @@ +import React, { useState, useEffect, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + Button, + Form, + Typography, + Banner, + Tabs, + TabPane, + Card, + Input, + InputNumber, + Switch, + TextArea, + Row, + Col, + Divider, +} from '@douyinfe/semi-ui'; +import { + IconCode, + IconPlus, + IconDelete, + IconRefresh, +} from '@douyinfe/semi-icons'; + +const { Text } = Typography; + +const JSONEditor = ({ + value = '', + onChange, + field, + label, + placeholder, + extraText, + extraFooter, + showClear = true, + template, + templateLabel, + editorType = 'keyValue', + rules = [], + formApi = null, + ...props +}) => { + const { t } = useTranslation(); + + // 初始化JSON数据 + const [jsonData, setJsonData] = useState(() => { + // 初始化时解析JSON数据 + if (typeof value === 'string' && value.trim()) { + try { + const parsed = JSON.parse(value); + return parsed; + } catch (error) { + return {}; + } + } + if (typeof value === 'object' && value !== null) { + return value; + } + return {}; + }); + + // 手动模式下的本地文本缓冲,避免无效 JSON 时被外部值重置 + const [manualText, setManualText] = useState(() => { + if (typeof value === 'string') return value; + if (value && typeof value === 'object') return JSON.stringify(value, null, 2); + return ''; + }); + + // 根据键数量决定默认编辑模式 + const [editMode, setEditMode] = useState(() => { + // 如果初始JSON数据的键数量大于10个,则默认使用手动模式 + if (typeof value === 'string' && value.trim()) { + try { + const parsed = JSON.parse(value); + const keyCount = Object.keys(parsed).length; + return keyCount > 10 ? 'manual' : 'visual'; + } catch (error) { + // JSON无效时默认显示手动编辑模式 + return 'manual'; + } + } + return 'visual'; + }); + const [jsonError, setJsonError] = useState(''); + + // 数据同步 - 当value变化时总是更新jsonData(如果JSON有效) + useEffect(() => { + try { + let parsed = {}; + if (typeof value === 'string' && value.trim()) { + parsed = JSON.parse(value); + } else if (typeof value === 'object' && value !== null) { + parsed = value; + } + setJsonData(parsed); + setJsonError(''); + } catch (error) { + console.log('JSON解析失败:', error.message); + setJsonError(error.message); + // JSON格式错误时不更新jsonData + } + }, [value]); + + // 外部 value 变化时,若不在手动模式,则同步手动文本;在手动模式下不打断用户输入 + useEffect(() => { + if (editMode !== 'manual') { + if (typeof value === 'string') setManualText(value); + else if (value && typeof value === 'object') setManualText(JSON.stringify(value, null, 2)); + else setManualText(''); + } + }, [value, editMode]); + + // 处理可视化编辑的数据变化 + const handleVisualChange = useCallback((newData) => { + setJsonData(newData); + setJsonError(''); + const jsonString = Object.keys(newData).length === 0 ? '' : JSON.stringify(newData, null, 2); + + // 通过formApi设置值(如果提供的话) + if (formApi && field) { + formApi.setValue(field, jsonString); + } + + onChange?.(jsonString); + }, [onChange, formApi, field]); + + // 处理手动编辑的数据变化(无效 JSON 不阻断输入,也不立刻回传上游) + const handleManualChange = useCallback((newValue) => { + setManualText(newValue); + if (newValue && newValue.trim()) { + try { + JSON.parse(newValue); + setJsonError(''); + onChange?.(newValue); + } catch (error) { + setJsonError(error.message); + // 无效 JSON 时不回传,避免外部值把输入重置 + } + } else { + setJsonError(''); + onChange?.(''); + } + }, [onChange]); + + // 切换编辑模式 + const toggleEditMode = useCallback(() => { + if (editMode === 'visual') { + // 从可视化模式切换到手动模式 + setManualText(Object.keys(jsonData).length === 0 ? '' : JSON.stringify(jsonData, null, 2)); + setEditMode('manual'); + } else { + // 从手动模式切换到可视化模式,需要验证JSON + try { + let parsed = {}; + if (manualText && manualText.trim()) { + parsed = JSON.parse(manualText); + } else if (typeof value === 'string' && value.trim()) { + parsed = JSON.parse(value); + } else if (typeof value === 'object' && value !== null) { + parsed = value; + } + setJsonData(parsed); + setJsonError(''); + setEditMode('visual'); + } catch (error) { + setJsonError(error.message); + // JSON格式错误时不切换模式 + return; + } + } + }, [editMode, value, manualText, jsonData]); + + // 添加键值对 + const addKeyValue = useCallback(() => { + const newData = { ...jsonData }; + const keys = Object.keys(newData); + let counter = 1; + let newKey = `field_${counter}`; + while (newData.hasOwnProperty(newKey)) { + counter += 1; + newKey = `field_${counter}`; + } + newData[newKey] = ''; + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + // 删除键值对 + const removeKeyValue = useCallback((keyToRemove) => { + const newData = { ...jsonData }; + delete newData[keyToRemove]; + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + // 更新键名 + const updateKey = useCallback((oldKey, newKey) => { + if (oldKey === newKey || !newKey) return; + const newData = {}; + Object.entries(jsonData).forEach(([k, v]) => { + if (k === oldKey) { + newData[newKey] = v; + } else { + newData[k] = v; + } + }); + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + // 更新值 + const updateValue = useCallback((key, newValue) => { + const newData = { ...jsonData }; + newData[key] = newValue; + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + // 填入模板 + const fillTemplate = useCallback(() => { + if (template) { + const templateString = JSON.stringify(template, null, 2); + + // 通过formApi设置值(如果提供的话) + if (formApi && field) { + formApi.setValue(field, templateString); + } + + // 同步内部与外部值,避免出现杂字符 + setManualText(templateString); + setJsonData(template); + onChange?.(templateString); + + // 清除错误状态 + setJsonError(''); + } + }, [template, onChange, editMode, formApi, field]); + + // 渲染键值对编辑器 + const renderKeyValueEditor = () => { + if (typeof jsonData !== 'object' || jsonData === null) { + return ( +
+
+ +
+ + {t('无效的JSON数据,请检查格式')} + +
+ ); + } + const entries = Object.entries(jsonData); + + return ( +
+ {entries.length === 0 && ( +
+ + {t('暂无数据,点击下方按钮添加键值对')} + +
+ )} + + {entries.map(([key, value], index) => ( + + + updateKey(key, newKey)} + /> + + + {renderValueInput(key, value)} + + + +
+ + ); + }; + + // 添加嵌套对象 + const flattenObject = useCallback((parentKey) => { + const newData = { ...jsonData }; + let primitive = ''; + const obj = newData[parentKey]; + if (obj && typeof obj === 'object') { + const firstKey = Object.keys(obj)[0]; + if (firstKey !== undefined) { + const firstVal = obj[firstKey]; + if (typeof firstVal !== 'object') primitive = firstVal; + } + } + newData[parentKey] = primitive; + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + const addNestedObject = useCallback((parentKey) => { + const newData = { ...jsonData }; + if (typeof newData[parentKey] !== 'object' || newData[parentKey] === null) { + newData[parentKey] = {}; + } + const existingKeys = Object.keys(newData[parentKey]); + let counter = 1; + let newKey = `field_${counter}`; + while (newData[parentKey].hasOwnProperty(newKey)) { + counter += 1; + newKey = `field_${counter}`; + } + newData[parentKey][newKey] = ''; + handleVisualChange(newData); + }, [jsonData, handleVisualChange]); + + // 渲染参数值输入控件(支持嵌套) + const renderValueInput = (key, value) => { + const valueType = typeof value; + + if (valueType === 'boolean') { + return ( +
+ updateValue(key, newValue)} + /> + + {value ? t('true') : t('false')} + +
+ ); + } + + if (valueType === 'number') { + return ( + updateValue(key, newValue)} + style={{ width: '100%' }} + step={key === 'temperature' ? 0.1 : 1} + precision={key === 'temperature' ? 2 : 0} + placeholder={t('输入数字')} + /> + ); + } + + if (valueType === 'object' && value !== null) { + // 渲染嵌套对象 + const entries = Object.entries(value); + return ( + + {entries.length === 0 && ( + + {t('空对象,点击下方加号添加字段')} + + )} + + {entries.map(([nestedKey, nestedValue], index) => ( + + + { + const newData = { ...jsonData }; + const oldValue = newData[key][nestedKey]; + delete newData[key][nestedKey]; + newData[key][newKey] = oldValue; + handleVisualChange(newData); + }} + /> + + + {typeof nestedValue === 'object' && nestedValue !== null ? ( +