diff --git a/common/utils.go b/common/utils.go index d9db67d0..17aecd95 100644 --- a/common/utils.go +++ b/common/utils.go @@ -13,6 +13,7 @@ import ( "math/big" "math/rand" "net" + "net/url" "os" "os/exec" "runtime" @@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64 } return strconv.ParseFloat(durationStr, 64) } + +// BuildURL concatenates base and endpoint, returns the complete url string +func BuildURL(base string, endpoint string) string { + u, err := url.Parse(base) + if err != nil { + return base + endpoint + } + end := endpoint + if end == "" { + end = "/" + } + ref, err := url.Parse(end) + if err != nil { + return base + endpoint + } + return u.ResolveReference(ref).String() +} diff --git a/controller/ratio_config.go b/controller/ratio_config.go new file mode 100644 index 00000000..6ddc3d9e --- /dev/null +++ b/controller/ratio_config.go @@ -0,0 +1,24 @@ +package controller + +import ( + "net/http" + "one-api/setting/ratio_setting" + + "github.com/gin-gonic/gin" +) + +func GetRatioConfig(c *gin.Context) { + if !ratio_setting.IsExposeRatioEnabled() { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "倍率配置接口未启用", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": ratio_setting.GetExposedData(), + }) +} \ No newline at end of file diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go new file mode 100644 index 00000000..c7494b5b --- /dev/null +++ b/controller/ratio_sync.go @@ -0,0 +1,290 @@ +package controller + +import ( + "encoding/json" + "net/http" + "one-api/model" + "one-api/setting/ratio_setting" + "one-api/dto" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +type upstreamResult struct { + Name string `json:"name"` + Data map[string]any `json:"data,omitempty"` + Err string `json:"err,omitempty"` +} + +type TestResult struct { + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +type DifferenceItem struct { + Current interface{} `json:"current"` // 当前本地值,可能为null + Upstreams map[string]interface{} `json:"upstreams"` // 上游值:具体值/"same"/null +} + +// SyncableChannel 可同步的渠道信息 +type SyncableChannel struct { + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} + +// FetchUpstreamRatios 后端并发拉取上游倍率 +func FetchUpstreamRatios(c *gin.Context) { + var req dto.UpstreamRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } + + if req.Timeout <= 0 { + req.Timeout = 10 + } + + // build upstream list from ids + custom + var upstreams []dto.UpstreamDTO + if len(req.ChannelIDs) > 0 { + // convert []int64 -> []int for model function + intIds := make([]int, 0, len(req.ChannelIDs)) + for _, id64 := range req.ChannelIDs { + intIds = append(intIds, int(id64)) + } + dbChannels, _ := model.GetChannelsByIds(intIds) + for _, ch := range dbChannels { + upstreams = append(upstreams, dto.UpstreamDTO{ + Name: ch.Name, + BaseURL: ch.GetBaseURL(), + Endpoint: "", // assume default endpoint + }) + } + } + upstreams = append(upstreams, req.CustomChannels...) + + var wg sync.WaitGroup + ch := make(chan upstreamResult, len(upstreams)) + + for _, chn := range upstreams { + wg.Add(1) + go func(chItem dto.UpstreamDTO) { + defer wg.Done() + endpoint := chItem.Endpoint + if endpoint == "" { + endpoint = "/api/ratio_config" + } + url := chItem.BaseURL + endpoint + client := http.Client{Timeout: time.Duration(req.Timeout) * time.Second} + resp, err := client.Get(url) + if err != nil { + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + ch <- upstreamResult{Name: chItem.Name, Err: resp.Status} + return + } + var body struct { + Success bool `json:"success"` + Data map[string]any `json:"data"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + if !body.Success { + ch <- upstreamResult{Name: chItem.Name, Err: body.Message} + return + } + ch <- upstreamResult{Name: chItem.Name, Data: body.Data} + }(chn) + } + + wg.Wait() + close(ch) + + // 本地倍率配置 + localData := ratio_setting.GetExposedData() + + var testResults []dto.TestResult + var successfulChannels []struct { + name string + data map[string]any + } + + for r := range ch { + if r.Err != "" { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "error", + Error: r.Err, + }) + } else { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "success", + }) + successfulChannels = append(successfulChannels, struct { + name string + data map[string]any + }{name: r.Name, data: r.Data}) + } + } + + // 构建差异化数据 + differences := buildDifferences(localData, successfulChannels) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "differences": differences, + "test_results": testResults, + }, + }) +} + +// buildDifferences 构建差异化数据,只返回有意义的差异 +func buildDifferences(localData map[string]any, successfulChannels []struct { + name string + data map[string]any +}) map[string]map[string]dto.DifferenceItem { + differences := make(map[string]map[string]dto.DifferenceItem) + ratioTypes := []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} + + // 收集所有模型名称 + allModels := make(map[string]struct{}) + + // 从本地数据收集模型名称 + for _, ratioType := range ratioTypes { + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + for modelName := range localRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + // 从上游数据收集模型名称 + for _, channel := range successfulChannels { + for _, ratioType := range ratioTypes { + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + for modelName := range upstreamRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + // 对每个模型和每个比率类型进行分析 + for modelName := range allModels { + for _, ratioType := range ratioTypes { + // 获取本地值 + var localValue interface{} = nil + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + if val, exists := localRatio[modelName]; exists { + localValue = val + } + } + } + + // 收集上游值 + upstreamValues := make(map[string]interface{}) + hasUpstreamValue := false + hasDifference := false + + for _, channel := range successfulChannels { + var upstreamValue interface{} = nil + + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + if val, exists := upstreamRatio[modelName]; exists { + upstreamValue = val + hasUpstreamValue = true + + // 检查是否与本地值不同 + if localValue != nil && localValue != val { + hasDifference = true + } else if localValue == val { + upstreamValue = "same" + } + } + } + + // 如果本地值为空但上游有值,这也是差异 + if localValue == nil && upstreamValue != nil && upstreamValue != "same" { + hasDifference = true + } + + upstreamValues[channel.name] = upstreamValue + } + + // 应用过滤逻辑 + shouldInclude := false + + if localValue != nil { + // 规则1: 本地值存在,至少有一个上游与本地值不同 + if hasDifference { + shouldInclude = true + } + // 规则2: 本地值存在,但所有上游都未设置 - 不包含 + } else { + // 规则3: 本地值不存在,至少有一个上游设置了值 + if hasUpstreamValue { + shouldInclude = true + } + } + + if shouldInclude { + if differences[modelName] == nil { + differences[modelName] = make(map[string]dto.DifferenceItem) + } + differences[modelName][ratioType] = dto.DifferenceItem{ + Current: localValue, + Upstreams: upstreamValues, + } + } + } + } + + return differences +} + +// GetSyncableChannels 获取可用于倍率同步的渠道(base_url 不为空的渠道) +func GetSyncableChannels(c *gin.Context) { + channels, err := model.GetAllChannels(0, 0, true, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + var syncableChannels []dto.SyncableChannel + for _, channel := range channels { + // 只返回 base_url 不为空的渠道 + if channel.GetBaseURL() != "" { + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: channel.Id, + Name: channel.Name, + BaseURL: channel.GetBaseURL(), + Status: channel.Status, + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": syncableChannels, + }) +} \ No newline at end of file diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go new file mode 100644 index 00000000..4f2fe06d --- /dev/null +++ b/dto/ratio_sync.go @@ -0,0 +1,50 @@ +package dto + +// UpstreamDTO 提交到后端同步倍率的上游渠道信息 +// Endpoint 可以为空,后端会默认使用 /api/ratio_config +// BaseURL 必须以 http/https 开头,不要以 / 结尾 +// 例如: https://api.example.com +// Endpoint: /api/ratio_config +// 提交示例: +// { +// "name": "openai", +// "base_url": "https://api.openai.com", +// "endpoint": "/ratio_config" +// } + +type UpstreamDTO struct { + Name string `json:"name" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + Endpoint string `json:"endpoint"` +} + +type UpstreamRequest struct { + ChannelIDs []int64 `json:"channel_ids"` + CustomChannels []UpstreamDTO `json:"custom_channels"` + Timeout int `json:"timeout"` +} + +// TestResult 上游测试连通性结果 +type TestResult struct { + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +// DifferenceItem 差异项 +// Current 为本地值,可能为 nil +// Upstreams 为各渠道的上游值,具体数值 / "same" / nil + +type DifferenceItem struct { + Current interface{} `json:"current"` + Upstreams map[string]interface{} `json:"upstreams"` +} + +// SyncableChannel 可同步的渠道信息(base_url 不为空) + +type SyncableChannel struct { + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} \ No newline at end of file diff --git a/model/option.go b/model/option.go index 43c0a644..97f7baae 100644 --- a/model/option.go +++ b/model/option.go @@ -126,6 +126,7 @@ func InitOptionMap() { common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled()) // 自动添加所有注册的模型配置 modelConfigs := config.GlobalConfig.ExportAllConfigs() @@ -266,6 +267,8 @@ func updateOptionMap(key string, value string) (err error) { setting.WorkerAllowHttpImageRequestEnabled = boolValue case "DefaultUseAutoGroup": setting.DefaultUseAutoGroup = boolValue + case "ExposeRatioEnabled": + ratio_setting.SetExposeRatioEnabled(boolValue) } } switch key { diff --git a/router/api-router.go b/router/api-router.go index 45930246..badfa7bf 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind) apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin) apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind) + apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig) userRoute := apiRouter.Group("/user") { @@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) { optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio) optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } + ratioSyncRoute := apiRouter.Group("/ratio_sync") + ratioSyncRoute.Use(middleware.RootAuth()) + { + ratioSyncRoute.GET("/channels", controller.GetSyncableChannels) + ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios) + } channelRoute := apiRouter.Group("/channel") channelRoute.Use(middleware.AdminAuth()) { diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index aa934b22..51d473a8 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error { cacheRatioMapMutex.Lock() defer cacheRatioMapMutex.Unlock() cacheRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetCacheRatio returns the cache ratio for a model @@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) { } return ratio, true } + +func GetCacheRatioCopy() map[string]float64 { + cacheRatioMapMutex.RLock() + defer cacheRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(cacheRatioMap)) + for k, v := range cacheRatioMap { + copyMap[k] = v + } + return copyMap +} diff --git a/setting/ratio_setting/expose_ratio.go b/setting/ratio_setting/expose_ratio.go new file mode 100644 index 00000000..8fca0bcb --- /dev/null +++ b/setting/ratio_setting/expose_ratio.go @@ -0,0 +1,17 @@ +package ratio_setting + +import "sync/atomic" + +var exposeRatioEnabled atomic.Bool + +func init() { + exposeRatioEnabled.Store(false) +} + +func SetExposeRatioEnabled(enabled bool) { + exposeRatioEnabled.Store(enabled) +} + +func IsExposeRatioEnabled() bool { + return exposeRatioEnabled.Load() +} \ No newline at end of file diff --git a/setting/ratio_setting/exposed_cache.go b/setting/ratio_setting/exposed_cache.go new file mode 100644 index 00000000..9e5b6c30 --- /dev/null +++ b/setting/ratio_setting/exposed_cache.go @@ -0,0 +1,55 @@ +package ratio_setting + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/gin-gonic/gin" +) + +const exposedDataTTL = 30 * time.Second + +type exposedCache struct { + data gin.H + expiresAt time.Time +} + +var ( + exposedData atomic.Value + rebuildMu sync.Mutex +) + +func InvalidateExposedDataCache() { + exposedData.Store((*exposedCache)(nil)) +} + +func cloneGinH(src gin.H) gin.H { + dst := make(gin.H, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func GetExposedData() gin.H { + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + rebuildMu.Lock() + defer rebuildMu.Unlock() + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + newData := gin.H{ + "model_ratio": GetModelRatioCopy(), + "completion_ratio": GetCompletionRatioCopy(), + "cache_ratio": GetCacheRatioCopy(), + "model_price": GetModelPriceCopy(), + } + exposedData.Store(&exposedCache{ + data: newData, + expiresAt: time.Now().Add(exposedDataTTL), + }) + return cloneGinH(newData) +} \ No newline at end of file diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 3102dfe9..1eaf25b1 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -317,7 +317,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error { modelPriceMapMutex.Lock() defer modelPriceMapMutex.Unlock() modelPriceMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelPriceMap) + err := json.Unmarshal([]byte(jsonStr), &modelPriceMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false @@ -345,7 +349,11 @@ func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() modelRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelRatioMap) + err := json.Unmarshal([]byte(jsonStr), &modelRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // 处理带有思考预算的模型名称,方便统一定价 @@ -405,7 +413,11 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { CompletionRatioMutex.Lock() defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &CompletionRatio) + err := json.Unmarshal([]byte(jsonStr), &CompletionRatio) + if err == nil { + InvalidateExposedDataCache() + } + return err } func GetCompletionRatio(name string) float64 { @@ -609,3 +621,33 @@ func GetImageRatio(name string) (float64, bool) { } return ratio, true } + +func GetModelRatioCopy() map[string]float64 { + modelRatioMapMutex.RLock() + defer modelRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelRatioMap)) + for k, v := range modelRatioMap { + copyMap[k] = v + } + return copyMap +} + +func GetModelPriceCopy() map[string]float64 { + modelPriceMapMutex.RLock() + defer modelPriceMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelPriceMap)) + for k, v := range modelPriceMap { + copyMap[k] = v + } + return copyMap +} + +func GetCompletionRatioCopy() map[string]float64 { + CompletionRatioMutex.RLock() + defer CompletionRatioMutex.RUnlock() + copyMap := make(map[string]float64, len(CompletionRatio)) + for k, v := range CompletionRatio { + copyMap[k] = v + } + return copyMap +} diff --git a/web/src/components/settings/ChannelSelectorModal.js b/web/src/components/settings/ChannelSelectorModal.js new file mode 100644 index 00000000..c393d97f --- /dev/null +++ b/web/src/components/settings/ChannelSelectorModal.js @@ -0,0 +1,154 @@ +import React from 'react'; +import { + Modal, + Transfer, + Input, + Card, + Space, + Button, + Checkbox, +} from '@douyinfe/semi-ui'; +import { IconPlus, IconClose } from '@douyinfe/semi-icons'; + +/** + * ChannelSelectorModal + * 负责选择同步渠道、测试与批量测试等 UI,纯展示组件。 + * 业务状态与动作通过 props 注入,保持可复用与可测试。 + */ +export default function ChannelSelectorModal({ + t, + visible, + onCancel, + onOk, + // 渠道与选择 + allChannels = [], + selectedChannelIds = [], + setSelectedChannelIds, + // 自定义渠道 + customUrl, + setCustomUrl, + customEndpoint, + setCustomEndpoint, + customChannelTesting, + addCustomChannel, + // 渠道端点 + channelEndpoints, + updateChannelEndpoint, + // 测试相关 +}) { + // Transfer 自定义渲染 + const renderSourceItem = (item) => { + const channelId = item.key || item.value; + const currentEndpoint = channelEndpoints[channelId]; + const baseUrl = item._originalData?.base_url || ''; + + return ( +