diff --git a/common/utils.go b/common/utils.go index d9db67d0..17aecd95 100644 --- a/common/utils.go +++ b/common/utils.go @@ -13,6 +13,7 @@ import ( "math/big" "math/rand" "net" + "net/url" "os" "os/exec" "runtime" @@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64 } return strconv.ParseFloat(durationStr, 64) } + +// BuildURL concatenates base and endpoint, returns the complete url string +func BuildURL(base string, endpoint string) string { + u, err := url.Parse(base) + if err != nil { + return base + endpoint + } + end := endpoint + if end == "" { + end = "/" + } + ref, err := url.Parse(end) + if err != nil { + return base + endpoint + } + return u.ResolveReference(ref).String() +} diff --git a/controller/ratio_config.go b/controller/ratio_config.go new file mode 100644 index 00000000..6ddc3d9e --- /dev/null +++ b/controller/ratio_config.go @@ -0,0 +1,24 @@ +package controller + +import ( + "net/http" + "one-api/setting/ratio_setting" + + "github.com/gin-gonic/gin" +) + +func GetRatioConfig(c *gin.Context) { + if !ratio_setting.IsExposeRatioEnabled() { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "倍率配置接口未启用", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": ratio_setting.GetExposedData(), + }) +} \ No newline at end of file diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go new file mode 100644 index 00000000..c7494b5b --- /dev/null +++ b/controller/ratio_sync.go @@ -0,0 +1,290 @@ +package controller + +import ( + "encoding/json" + "net/http" + "one-api/model" + "one-api/setting/ratio_setting" + "one-api/dto" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +type upstreamResult struct { + Name string `json:"name"` + Data map[string]any `json:"data,omitempty"` + Err string `json:"err,omitempty"` +} + +type TestResult struct { + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +type DifferenceItem struct { + Current interface{} `json:"current"` // 当前本地值,可能为null + Upstreams map[string]interface{} `json:"upstreams"` // 上游值:具体值/"same"/null +} + +// SyncableChannel 可同步的渠道信息 +type SyncableChannel struct { + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} + +// FetchUpstreamRatios 后端并发拉取上游倍率 +func FetchUpstreamRatios(c *gin.Context) { + var req dto.UpstreamRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } + + if req.Timeout <= 0 { + req.Timeout = 10 + } + + // build upstream list from ids + custom + var upstreams []dto.UpstreamDTO + if len(req.ChannelIDs) > 0 { + // convert []int64 -> []int for model function + intIds := make([]int, 0, len(req.ChannelIDs)) + for _, id64 := range req.ChannelIDs { + intIds = append(intIds, int(id64)) + } + dbChannels, _ := model.GetChannelsByIds(intIds) + for _, ch := range dbChannels { + upstreams = append(upstreams, dto.UpstreamDTO{ + Name: ch.Name, + BaseURL: ch.GetBaseURL(), + Endpoint: "", // assume default endpoint + }) + } + } + upstreams = append(upstreams, req.CustomChannels...) + + var wg sync.WaitGroup + ch := make(chan upstreamResult, len(upstreams)) + + for _, chn := range upstreams { + wg.Add(1) + go func(chItem dto.UpstreamDTO) { + defer wg.Done() + endpoint := chItem.Endpoint + if endpoint == "" { + endpoint = "/api/ratio_config" + } + url := chItem.BaseURL + endpoint + client := http.Client{Timeout: time.Duration(req.Timeout) * time.Second} + resp, err := client.Get(url) + if err != nil { + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + ch <- upstreamResult{Name: chItem.Name, Err: resp.Status} + return + } + var body struct { + Success bool `json:"success"` + Data map[string]any `json:"data"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + if !body.Success { + ch <- upstreamResult{Name: chItem.Name, Err: body.Message} + return + } + ch <- upstreamResult{Name: chItem.Name, Data: body.Data} + }(chn) + } + + wg.Wait() + close(ch) + + // 本地倍率配置 + localData := ratio_setting.GetExposedData() + + var testResults []dto.TestResult + var successfulChannels []struct { + name string + data map[string]any + } + + for r := range ch { + if r.Err != "" { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "error", + Error: r.Err, + }) + } else { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "success", + }) + successfulChannels = append(successfulChannels, struct { + name string + data map[string]any + }{name: r.Name, data: r.Data}) + } + } + + // 构建差异化数据 + differences := buildDifferences(localData, successfulChannels) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "differences": differences, + "test_results": testResults, + }, + }) +} + +// buildDifferences 构建差异化数据,只返回有意义的差异 +func buildDifferences(localData map[string]any, successfulChannels []struct { + name string + data map[string]any +}) map[string]map[string]dto.DifferenceItem { + differences := make(map[string]map[string]dto.DifferenceItem) + ratioTypes := []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} + + // 收集所有模型名称 + allModels := make(map[string]struct{}) + + // 从本地数据收集模型名称 + for _, ratioType := range ratioTypes { + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + for modelName := range localRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + // 从上游数据收集模型名称 + for _, channel := range successfulChannels { + for _, ratioType := range ratioTypes { + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + for modelName := range upstreamRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + // 对每个模型和每个比率类型进行分析 + for modelName := range allModels { + for _, ratioType := range ratioTypes { + // 获取本地值 + var localValue interface{} = nil + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + if val, exists := localRatio[modelName]; exists { + localValue = val + } + } + } + + // 收集上游值 + upstreamValues := make(map[string]interface{}) + hasUpstreamValue := false + hasDifference := false + + for _, channel := range successfulChannels { + var upstreamValue interface{} = nil + + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + if val, exists := upstreamRatio[modelName]; exists { + upstreamValue = val + hasUpstreamValue = true + + // 检查是否与本地值不同 + if localValue != nil && localValue != val { + hasDifference = true + } else if localValue == val { + upstreamValue = "same" + } + } + } + + // 如果本地值为空但上游有值,这也是差异 + if localValue == nil && upstreamValue != nil && upstreamValue != "same" { + hasDifference = true + } + + upstreamValues[channel.name] = upstreamValue + } + + // 应用过滤逻辑 + shouldInclude := false + + if localValue != nil { + // 规则1: 本地值存在,至少有一个上游与本地值不同 + if hasDifference { + shouldInclude = true + } + // 规则2: 本地值存在,但所有上游都未设置 - 不包含 + } else { + // 规则3: 本地值不存在,至少有一个上游设置了值 + if hasUpstreamValue { + shouldInclude = true + } + } + + if shouldInclude { + if differences[modelName] == nil { + differences[modelName] = make(map[string]dto.DifferenceItem) + } + differences[modelName][ratioType] = dto.DifferenceItem{ + Current: localValue, + Upstreams: upstreamValues, + } + } + } + } + + return differences +} + +// GetSyncableChannels 获取可用于倍率同步的渠道(base_url 不为空的渠道) +func GetSyncableChannels(c *gin.Context) { + channels, err := model.GetAllChannels(0, 0, true, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + var syncableChannels []dto.SyncableChannel + for _, channel := range channels { + // 只返回 base_url 不为空的渠道 + if channel.GetBaseURL() != "" { + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: channel.Id, + Name: channel.Name, + BaseURL: channel.GetBaseURL(), + Status: channel.Status, + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": syncableChannels, + }) +} \ No newline at end of file diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go new file mode 100644 index 00000000..4f2fe06d --- /dev/null +++ b/dto/ratio_sync.go @@ -0,0 +1,50 @@ +package dto + +// UpstreamDTO 提交到后端同步倍率的上游渠道信息 +// Endpoint 可以为空,后端会默认使用 /api/ratio_config +// BaseURL 必须以 http/https 开头,不要以 / 结尾 +// 例如: https://api.example.com +// Endpoint: /api/ratio_config +// 提交示例: +// { +// "name": "openai", +// "base_url": "https://api.openai.com", +// "endpoint": "/ratio_config" +// } + +type UpstreamDTO struct { + Name string `json:"name" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + Endpoint string `json:"endpoint"` +} + +type UpstreamRequest struct { + ChannelIDs []int64 `json:"channel_ids"` + CustomChannels []UpstreamDTO `json:"custom_channels"` + Timeout int `json:"timeout"` +} + +// TestResult 上游测试连通性结果 +type TestResult struct { + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +// DifferenceItem 差异项 +// Current 为本地值,可能为 nil +// Upstreams 为各渠道的上游值,具体数值 / "same" / nil + +type DifferenceItem struct { + Current interface{} `json:"current"` + Upstreams map[string]interface{} `json:"upstreams"` +} + +// SyncableChannel 可同步的渠道信息(base_url 不为空) + +type SyncableChannel struct { + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} \ No newline at end of file diff --git a/model/option.go b/model/option.go index 43c0a644..97f7baae 100644 --- a/model/option.go +++ b/model/option.go @@ -126,6 +126,7 @@ func InitOptionMap() { common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled()) // 自动添加所有注册的模型配置 modelConfigs := config.GlobalConfig.ExportAllConfigs() @@ -266,6 +267,8 @@ func updateOptionMap(key string, value string) (err error) { setting.WorkerAllowHttpImageRequestEnabled = boolValue case "DefaultUseAutoGroup": setting.DefaultUseAutoGroup = boolValue + case "ExposeRatioEnabled": + ratio_setting.SetExposeRatioEnabled(boolValue) } } switch key { diff --git a/router/api-router.go b/router/api-router.go index 45930246..badfa7bf 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind) apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin) apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind) + apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig) userRoute := apiRouter.Group("/user") { @@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) { optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio) optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } + ratioSyncRoute := apiRouter.Group("/ratio_sync") + ratioSyncRoute.Use(middleware.RootAuth()) + { + ratioSyncRoute.GET("/channels", controller.GetSyncableChannels) + ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios) + } channelRoute := apiRouter.Group("/channel") channelRoute.Use(middleware.AdminAuth()) { diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index aa934b22..51d473a8 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error { cacheRatioMapMutex.Lock() defer cacheRatioMapMutex.Unlock() cacheRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetCacheRatio returns the cache ratio for a model @@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) { } return ratio, true } + +func GetCacheRatioCopy() map[string]float64 { + cacheRatioMapMutex.RLock() + defer cacheRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(cacheRatioMap)) + for k, v := range cacheRatioMap { + copyMap[k] = v + } + return copyMap +} diff --git a/setting/ratio_setting/expose_ratio.go b/setting/ratio_setting/expose_ratio.go new file mode 100644 index 00000000..8fca0bcb --- /dev/null +++ b/setting/ratio_setting/expose_ratio.go @@ -0,0 +1,17 @@ +package ratio_setting + +import "sync/atomic" + +var exposeRatioEnabled atomic.Bool + +func init() { + exposeRatioEnabled.Store(false) +} + +func SetExposeRatioEnabled(enabled bool) { + exposeRatioEnabled.Store(enabled) +} + +func IsExposeRatioEnabled() bool { + return exposeRatioEnabled.Load() +} \ No newline at end of file diff --git a/setting/ratio_setting/exposed_cache.go b/setting/ratio_setting/exposed_cache.go new file mode 100644 index 00000000..9e5b6c30 --- /dev/null +++ b/setting/ratio_setting/exposed_cache.go @@ -0,0 +1,55 @@ +package ratio_setting + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/gin-gonic/gin" +) + +const exposedDataTTL = 30 * time.Second + +type exposedCache struct { + data gin.H + expiresAt time.Time +} + +var ( + exposedData atomic.Value + rebuildMu sync.Mutex +) + +func InvalidateExposedDataCache() { + exposedData.Store((*exposedCache)(nil)) +} + +func cloneGinH(src gin.H) gin.H { + dst := make(gin.H, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func GetExposedData() gin.H { + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + rebuildMu.Lock() + defer rebuildMu.Unlock() + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + newData := gin.H{ + "model_ratio": GetModelRatioCopy(), + "completion_ratio": GetCompletionRatioCopy(), + "cache_ratio": GetCacheRatioCopy(), + "model_price": GetModelPriceCopy(), + } + exposedData.Store(&exposedCache{ + data: newData, + expiresAt: time.Now().Add(exposedDataTTL), + }) + return cloneGinH(newData) +} \ No newline at end of file diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 3102dfe9..1eaf25b1 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -317,7 +317,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error { modelPriceMapMutex.Lock() defer modelPriceMapMutex.Unlock() modelPriceMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelPriceMap) + err := json.Unmarshal([]byte(jsonStr), &modelPriceMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false @@ -345,7 +349,11 @@ func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() modelRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelRatioMap) + err := json.Unmarshal([]byte(jsonStr), &modelRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // 处理带有思考预算的模型名称,方便统一定价 @@ -405,7 +413,11 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { CompletionRatioMutex.Lock() defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &CompletionRatio) + err := json.Unmarshal([]byte(jsonStr), &CompletionRatio) + if err == nil { + InvalidateExposedDataCache() + } + return err } func GetCompletionRatio(name string) float64 { @@ -609,3 +621,33 @@ func GetImageRatio(name string) (float64, bool) { } return ratio, true } + +func GetModelRatioCopy() map[string]float64 { + modelRatioMapMutex.RLock() + defer modelRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelRatioMap)) + for k, v := range modelRatioMap { + copyMap[k] = v + } + return copyMap +} + +func GetModelPriceCopy() map[string]float64 { + modelPriceMapMutex.RLock() + defer modelPriceMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelPriceMap)) + for k, v := range modelPriceMap { + copyMap[k] = v + } + return copyMap +} + +func GetCompletionRatioCopy() map[string]float64 { + CompletionRatioMutex.RLock() + defer CompletionRatioMutex.RUnlock() + copyMap := make(map[string]float64, len(CompletionRatio)) + for k, v := range CompletionRatio { + copyMap[k] = v + } + return copyMap +} diff --git a/web/src/components/settings/ChannelSelectorModal.js b/web/src/components/settings/ChannelSelectorModal.js new file mode 100644 index 00000000..c393d97f --- /dev/null +++ b/web/src/components/settings/ChannelSelectorModal.js @@ -0,0 +1,154 @@ +import React from 'react'; +import { + Modal, + Transfer, + Input, + Card, + Space, + Button, + Checkbox, +} from '@douyinfe/semi-ui'; +import { IconPlus, IconClose } from '@douyinfe/semi-icons'; + +/** + * ChannelSelectorModal + * 负责选择同步渠道、测试与批量测试等 UI,纯展示组件。 + * 业务状态与动作通过 props 注入,保持可复用与可测试。 + */ +export default function ChannelSelectorModal({ + t, + visible, + onCancel, + onOk, + // 渠道与选择 + allChannels = [], + selectedChannelIds = [], + setSelectedChannelIds, + // 自定义渠道 + customUrl, + setCustomUrl, + customEndpoint, + setCustomEndpoint, + customChannelTesting, + addCustomChannel, + // 渠道端点 + channelEndpoints, + updateChannelEndpoint, + // 测试相关 +}) { + // Transfer 自定义渲染 + const renderSourceItem = (item) => { + const channelId = item.key || item.value; + const currentEndpoint = channelEndpoints[channelId]; + const baseUrl = item._originalData?.base_url || ''; + + return ( +
+
+
+ + {item.label} + +
+
+ + {baseUrl} + + updateChannelEndpoint(channelId, value)} + placeholder="/api/ratio_config" + className="flex-1 text-xs" + style={{ fontSize: '12px' }} + /> +
+
+
+ ); + }; + + const renderSelectedItem = (item) => { + const channelId = item.key || item.value; + const currentEndpoint = channelEndpoints[channelId]; + const baseUrl = item._originalData?.base_url || ''; + + return ( +
+
+
+ {item.label} + +
+
+ + {baseUrl} + + + {currentEndpoint} + +
+
+
+ ); + }; + + const channelFilter = (input, item) => item.label.toLowerCase().includes(input.toLowerCase()); + + return ( + {t('选择同步渠道')}} + width={1000} + > + + + + + + + + + + + + + ); +} \ No newline at end of file diff --git a/web/src/components/settings/RatioSetting.js b/web/src/components/settings/RatioSetting.js index bf97282c..1d87c6de 100644 --- a/web/src/components/settings/RatioSetting.js +++ b/web/src/components/settings/RatioSetting.js @@ -6,6 +6,7 @@ import GroupRatioSettings from '../../pages/Setting/Ratio/GroupRatioSettings.js' import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js'; import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js'; import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js'; +import UpstreamRatioSync from '../../pages/Setting/Ratio/UpstreamRatioSync.js'; import { API, showError } from '../../helpers'; @@ -21,6 +22,7 @@ const RatioSetting = () => { GroupGroupRatio: '', AutoGroups: '', DefaultUseAutoGroup: false, + ExposeRatioEnabled: false, UserUsableGroups: '', }); @@ -48,7 +50,7 @@ const RatioSetting = () => { // 如果后端返回的不是合法 JSON,直接展示 } } - if (['DefaultUseAutoGroup'].includes(item.key)) { + if (['DefaultUseAutoGroup', 'ExposeRatioEnabled'].includes(item.key)) { newInputs[item.key] = item.value === 'true' ? true : false; } else { newInputs[item.key] = item.value; @@ -78,10 +80,6 @@ const RatioSetting = () => { return ( - {/* 分组倍率设置 */} - - - {/* 模型倍率设置以及可视化编辑器 */} @@ -100,8 +98,18 @@ const RatioSetting = () => { refresh={onRefresh} /> + + + + {/* 分组倍率设置 */} + + + ); }; diff --git a/web/src/helpers/ratio.js b/web/src/helpers/ratio.js new file mode 100644 index 00000000..fb293c80 --- /dev/null +++ b/web/src/helpers/ratio.js @@ -0,0 +1,20 @@ +export const DEFAULT_ENDPOINT = '/api/ratio_config'; + +/** + * buildEndpointUrl: 拼接 baseUrl 与 endpoint,确保不会出现双斜杠或缺失斜杠问题。 + * 使用 URL 构造函数保证协议/域名安全;若 baseUrl 非标准 URL,则退回字符串拼接。 + * @param {string} baseUrl - 基础地址,例如 https://api.example.com + * @param {string} endpoint - 接口路径,例如 /api/ratio_config + * @returns {string} + */ +export const buildEndpointUrl = (baseUrl, endpoint) => { + if (!baseUrl) return endpoint; + try { + return new URL(endpoint, baseUrl).toString(); + } catch (_) { + // fallback 处理不规范的 baseUrl + const cleanedBase = baseUrl.endsWith('/') ? baseUrl.slice(0, -1) : baseUrl; + const cleanedEndpoint = endpoint.startsWith('/') ? endpoint.slice(1) : endpoint; + return `${cleanedBase}/${cleanedEndpoint}`; + } +}; \ No newline at end of file diff --git a/web/src/pages/Detail/index.js b/web/src/pages/Detail/index.js index 15c02abf..0fd18d16 100644 --- a/web/src/pages/Detail/index.js +++ b/web/src/pages/Detail/index.js @@ -1112,7 +1112,6 @@ const Detail = (props) => { @@ -1389,7 +1388,6 @@ const Detail = (props) => { ) : ( + + + + setInputs({ ...inputs, ExposeRatioEnabled: value }) + } + /> + + diff --git a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js new file mode 100644 index 00000000..518c6468 --- /dev/null +++ b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js @@ -0,0 +1,596 @@ +import React, { useState, useCallback, useMemo } from 'react'; +import { + Button, + Table, + Tag, + Empty, + Checkbox, + Form, +} from '@douyinfe/semi-ui'; +import { + RefreshCcw, + CheckSquare, +} from 'lucide-react'; +import { + DEFAULT_ENDPOINT, + buildEndpointUrl, +} from '../../../helpers/ratio'; +import { API, showError, showSuccess, showWarning } from '../../../helpers'; +import { useTranslation } from 'react-i18next'; +import { + IllustrationNoResult, + IllustrationNoResultDark +} from '@douyinfe/semi-illustrations'; +import ChannelSelectorModal from '../../../components/settings/ChannelSelectorModal'; + +export default function UpstreamRatioSync(props) { + const { t } = useTranslation(); + const [modalVisible, setModalVisible] = useState(false); + const [loading, setLoading] = useState(false); + const [syncLoading, setSyncLoading] = useState(false); + + // 渠道选择相关 + const [allChannels, setAllChannels] = useState([]); + const [selectedChannelIds, setSelectedChannelIds] = useState([]); + + // 自定义渠道 + const [customUrl, setCustomUrl] = useState(''); + const [customEndpoint, setCustomEndpoint] = useState(DEFAULT_ENDPOINT); + const [customChannelTesting, setCustomChannelTesting] = useState(false); + + // 渠道端点配置 + const [channelEndpoints, setChannelEndpoints] = useState({}); // { channelId: endpoint } + + // 差异数据和测试结果 + const [differences, setDifferences] = useState({}); + const [testResults, setTestResults] = useState([]); + const [resolutions, setResolutions] = useState({}); + + // 分页相关状态 + const [currentPage, setCurrentPage] = useState(1); + const [pageSize, setPageSize] = useState(10); + + // 当前倍率快照 + const currentRatiosSnapshot = useMemo(() => ({ + model_ratio: JSON.parse(props.options.ModelRatio || '{}'), + completion_ratio: JSON.parse(props.options.CompletionRatio || '{}'), + cache_ratio: JSON.parse(props.options.CacheRatio || '{}'), + model_price: JSON.parse(props.options.ModelPrice || '{}'), + }), [props.options]); + + // 获取所有渠道 + const fetchAllChannels = async () => { + setLoading(true); + try { + const res = await API.get('/api/ratio_sync/channels'); + + if (res.data.success) { + const channels = res.data.data || []; + + // 转换为Transfer组件所需格式 + const transferData = channels.map(channel => ({ + key: channel.id, + label: channel.name, + value: channel.id, + disabled: false, // 所有渠道都可以选择 + _originalData: channel, + })); + + setAllChannels(transferData); + + // 初始化端点配置 + const initialEndpoints = {}; + transferData.forEach(channel => { + initialEndpoints[channel.key] = DEFAULT_ENDPOINT; + }); + setChannelEndpoints(initialEndpoints); + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('获取渠道失败:') + error.message); + } finally { + setLoading(false); + } + }; + + // 测试自定义渠道 + const testCustomChannel = async () => { + if (!customUrl) { + showWarning(t('请输入渠道地址')); + return false; + } + + setCustomChannelTesting(true); + + try { + const url = buildEndpointUrl(customUrl, customEndpoint); + const client = { timeout: 10000 }; + + const response = await fetch(url, { + method: 'GET', + signal: AbortSignal.timeout(client.timeout) + }); + + if (response.ok) { + const data = await response.json(); + if (data.success) { + return true; + } else { + showError(t('测试失败') + `: ${data.message || t('响应格式错误')}`); + return false; + } + } else { + showError(t('测试失败') + `: HTTP ${response.status}`); + return false; + } + } catch (error) { + showError(t('测试失败') + `: ${error.message || t('请求超时')}`); + return false; + } finally { + setCustomChannelTesting(false); + } + }; + + // 添加自定义渠道 + const addCustomChannel = async () => { + if (!customUrl) { + showWarning(t('请输入渠道地址')); + return; + } + + // 先测试渠道 + const testResult = await testCustomChannel(); + if (!testResult) { + return; + } + + let hostname; + try { + hostname = new URL(customUrl).hostname; + } catch (e) { + hostname = customUrl; + } + + const customId = `custom_${Date.now()}`; + const newChannel = { + key: customId, + label: hostname, + value: customId, + disabled: false, + _originalData: { + id: customId, + name: hostname, + base_url: customUrl.endsWith('/') ? customUrl.slice(0, -1) : customUrl, + status: 1, + is_custom: true, + }, + }; + + setAllChannels([...allChannels, newChannel]); + setSelectedChannelIds([...selectedChannelIds, customId]); + setChannelEndpoints(prev => ({ ...prev, [customId]: customEndpoint })); + setCustomUrl(''); + showSuccess(t('测试成功,渠道添加成功')); + }; + + // 确认选择渠道 + const confirmChannelSelection = () => { + const selected = allChannels + .filter(ch => selectedChannelIds.includes(ch.value)) + .map(ch => ch._originalData); + + if (selected.length === 0) { + showWarning(t('请至少选择一个渠道')); + return; + } + + setModalVisible(false); + fetchRatiosFromChannels(selected); + }; + + // 从选定渠道获取倍率 + const fetchRatiosFromChannels = async (channelList) => { + setSyncLoading(true); + + // 分离数据库渠道和自定义渠道 + const dbChannels = channelList.filter(ch => !ch.is_custom); + const customChannels = channelList.filter(ch => ch.is_custom); + + const payload = { + channel_ids: dbChannels.map(ch => parseInt(ch.id)), + custom_channels: customChannels.map(ch => ({ + name: ch.name, + base_url: ch.base_url, + endpoint: channelEndpoints[ch.id] || DEFAULT_ENDPOINT, + })), + timeout: 10 + }; + + try { + const res = await API.post('/api/ratio_sync/fetch', payload); + + if (!res.data.success) { + showError(res.data.message || t('后端请求失败')); + setSyncLoading(false); + return; + } + + const { differences = {}, test_results = [] } = res.data.data; + + // 显示测试结果 + const errorResults = test_results.filter(r => r.status === 'error'); + if (errorResults.length > 0) { + showWarning(t('部分渠道测试失败:') + errorResults.map(r => `${r.name}: ${r.error}`).join(', ')); + } + + setDifferences(differences); + setTestResults(test_results); + setResolutions({}); + + // 判断是否有差异 + if (Object.keys(differences).length === 0) { + showSuccess(t('已与上游倍率完全一致,无需同步')); + } + } catch (e) { + showError(t('请求后端接口失败:') + e.message); + } finally { + setSyncLoading(false); + } + }; + + // 解决冲突/选择值 + const selectValue = (model, ratioType, value) => { + setResolutions(prev => ({ + ...prev, + [model]: { + ...prev[model], + [ratioType]: value, + }, + })); + }; + + // 应用同步 + const applySync = async () => { + const currentRatios = { + ModelRatio: JSON.parse(props.options.ModelRatio || '{}'), + CompletionRatio: JSON.parse(props.options.CompletionRatio || '{}'), + CacheRatio: JSON.parse(props.options.CacheRatio || '{}'), + ModelPrice: JSON.parse(props.options.ModelPrice || '{}'), + }; + + // 应用已选择的值 + Object.entries(resolutions).forEach(([model, ratios]) => { + Object.entries(ratios).forEach(([ratioType, value]) => { + const optionKey = ratioType + .split('_') + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(''); + currentRatios[optionKey][model] = parseFloat(value); + }); + }); + + // 保存到后端 + setLoading(true); + try { + const updates = Object.entries(currentRatios).map(([key, value]) => + API.put('/api/option/', { + key, + value: JSON.stringify(value, null, 2), + }) + ); + + const results = await Promise.all(updates); + + if (results.every(res => res.data.success)) { + showSuccess(t('同步成功')); + props.refresh(); + // 清空状态 + setDifferences({}); + setTestResults([]); + setResolutions({}); + setSelectedChannelIds([]); + } else { + showError(t('部分保存失败')); + } + } catch (error) { + showError(t('保存失败')); + } finally { + setLoading(false); + } + }; + + // 计算当前页显示的数据 + const getCurrentPageData = (dataSource) => { + const startIndex = (currentPage - 1) * pageSize; + const endIndex = startIndex + pageSize; + return dataSource.slice(startIndex, endIndex); + }; + + // 渲染表格头部 + const renderHeader = () => ( +
+
+
+ + + {(() => { + // 检查是否有选择可应用的值 + const hasSelections = Object.keys(resolutions).length > 0; + + return ( + + ); + })()} +
+
+
+ ); + + // 渲染差异表格 + const renderDifferenceTable = () => { + // 构建数据源 + const dataSource = useMemo(() => { + const tmp = []; + + Object.entries(differences).forEach(([model, ratioTypes]) => { + Object.entries(ratioTypes).forEach(([ratioType, diff]) => { + tmp.push({ + key: `${model}_${ratioType}`, + model, + ratioType, + current: diff.current, + upstreams: diff.upstreams, + }); + }); + }); + + return tmp; + }, [differences]); + + // 收集所有上游渠道名称 + const upstreamNames = useMemo(() => { + const set = new Set(); + dataSource.forEach((row) => { + Object.keys(row.upstreams || {}).forEach((name) => set.add(name)); + }); + return Array.from(set); + }, [dataSource]); + + if (dataSource.length === 0) { + return ( + } + darkModeImage={} + description={Object.keys(differences).length === 0 ? t('已与上游倍率完全一致') : t('请先选择同步渠道')} + style={{ padding: 30 }} + /> + ); + } + + // 列定义 + const columns = [ + { + title: t('模型'), + dataIndex: 'model', + fixed: 'left', + width: 160, + }, + { + title: t('倍率类型'), + dataIndex: 'ratioType', + width: 140, + render: (text) => { + const typeMap = { + model_ratio: t('模型倍率'), + completion_ratio: t('补全倍率'), + cache_ratio: t('缓存倍率'), + model_price: t('固定价格'), + }; + return {typeMap[text] || text}; + }, + }, + { + title: t('当前值'), + dataIndex: 'current', + width: 100, + render: (text) => ( + + {text !== null && text !== undefined ? text : t('未设置')} + + ), + }, + // 动态上游列 + ...upstreamNames.map((upName) => { + // 计算该渠道的全选状态 + const channelStats = (() => { + let selectableCount = 0; // 可选择的项目数量 + let selectedCount = 0; // 已选择的项目数量 + + dataSource.forEach((row) => { + const upstreamVal = row.upstreams?.[upName]; + // 只有具体数值的才是可选择的(不是null、undefined或"same") + if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { + selectableCount++; + const isSelected = resolutions[row.model]?.[row.ratioType] === upstreamVal; + if (isSelected) { + selectedCount++; + } + } + }); + + return { + selectableCount, + selectedCount, + allSelected: selectableCount > 0 && selectedCount === selectableCount, + partiallySelected: selectedCount > 0 && selectedCount < selectableCount, + hasSelectableItems: selectableCount > 0 + }; + })(); + + // 处理全选/取消全选 + const handleBulkSelect = (checked) => { + setResolutions((prev) => { + const newRes = { ...prev }; + + dataSource.forEach((row) => { + const upstreamVal = row.upstreams?.[upName]; + if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { + if (checked) { + // 选择该值 + if (!newRes[row.model]) newRes[row.model] = {}; + newRes[row.model][row.ratioType] = upstreamVal; + } else { + // 取消选择该值 + if (newRes[row.model]) { + delete newRes[row.model][row.ratioType]; + if (Object.keys(newRes[row.model]).length === 0) { + delete newRes[row.model]; + } + } + } + } + }); + + return newRes; + }); + }; + + return { + title: channelStats.hasSelectableItems ? ( + handleBulkSelect(e.target.checked)} + > + {upName} + + ) : ( + {upName} + ), + dataIndex: upName, + width: 140, + render: (_, record) => { + const upstreamVal = record.upstreams?.[upName]; + + if (upstreamVal === null || upstreamVal === undefined) { + return {t('未设置')}; + } + + if (upstreamVal === 'same') { + return {t('与本地相同')}; + } + + // 有具体值,可以选择 + const isSelected = resolutions[record.model]?.[record.ratioType] === upstreamVal; + + return ( + { + const isChecked = e.target.checked; + if (isChecked) { + selectValue(record.model, record.ratioType, upstreamVal); + } else { + setResolutions((prev) => { + const newRes = { ...prev }; + if (newRes[record.model]) { + delete newRes[record.model][record.ratioType]; + if (Object.keys(newRes[record.model]).length === 0) { + delete newRes[record.model]; + } + } + return newRes; + }); + } + }} + > + {upstreamVal} + + ); + }, + }; + }), + ]; + + return ( + t('第 {{start}} - {{end}} 条,共 {{total}} 条', { + start: page.currentStart, + end: page.currentEnd, + total: dataSource.length, + }), + pageSizeOptions: ['5', '10', '20', '50'], + onChange: (page, size) => { + setCurrentPage(page); + setPageSize(size); + }, + onShowSizeChange: (current, size) => { + setCurrentPage(1); + setPageSize(size); + } + }} + scroll={{ x: 'max-content' }} + size='middle' + loading={loading || syncLoading} + className="rounded-xl overflow-hidden" + /> + ); + }; + + // 更新渠道端点 + const updateChannelEndpoint = useCallback((channelId, endpoint) => { + setChannelEndpoints(prev => ({ ...prev, [channelId]: endpoint })); + }, []); + + return ( + <> + + {renderDifferenceTable()} + + + setModalVisible(false)} + onOk={confirmChannelSelection} + allChannels={allChannels} + selectedChannelIds={selectedChannelIds} + setSelectedChannelIds={setSelectedChannelIds} + customUrl={customUrl} + setCustomUrl={setCustomUrl} + customEndpoint={customEndpoint} + setCustomEndpoint={setCustomEndpoint} + customChannelTesting={customChannelTesting} + addCustomChannel={addCustomChannel} + channelEndpoints={channelEndpoints} + updateChannelEndpoint={updateChannelEndpoint} + /> + + ); +} \ No newline at end of file