From 7975cdf3bf6dd19e4db18c90861148bd5ada1877 Mon Sep 17 00:00:00 2001 From: "Apple\\Apple" Date: Thu, 19 Jun 2025 08:57:34 +0800 Subject: [PATCH 01/19] =?UTF-8?q?=F0=9F=9A=80=20feat(ratio-sync):=20major?= =?UTF-8?q?=20refactor=20&=20UX=20overhaul=20for=20Upstream=20Ratio=20Sync?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/utils.go | 18 + controller/ratio_config.go | 24 + controller/ratio_sync.go | 290 +++++++++ dto/ratio_sync.go | 50 ++ model/option.go | 3 + router/api-router.go | 7 + setting/ratio_setting/cache_ratio.go | 16 +- setting/ratio_setting/expose_ratio.go | 17 + setting/ratio_setting/exposed_cache.go | 55 ++ setting/ratio_setting/model_ratio.go | 48 +- .../settings/ChannelSelectorModal.js | 154 +++++ web/src/components/settings/RatioSetting.js | 18 +- web/src/helpers/ratio.js | 20 + web/src/pages/Detail/index.js | 2 - .../pages/Setting/Ratio/ModelRatioSettings.js | 12 + .../pages/Setting/Ratio/UpstreamRatioSync.js | 596 ++++++++++++++++++ 16 files changed, 1319 insertions(+), 11 deletions(-) create mode 100644 controller/ratio_config.go create mode 100644 controller/ratio_sync.go create mode 100644 dto/ratio_sync.go create mode 100644 setting/ratio_setting/expose_ratio.go create mode 100644 setting/ratio_setting/exposed_cache.go create mode 100644 web/src/components/settings/ChannelSelectorModal.js create mode 100644 web/src/helpers/ratio.js create mode 100644 web/src/pages/Setting/Ratio/UpstreamRatioSync.js diff --git a/common/utils.go b/common/utils.go index d9db67d0..17aecd95 100644 --- a/common/utils.go +++ b/common/utils.go @@ -13,6 +13,7 @@ import ( "math/big" "math/rand" "net" + "net/url" "os" "os/exec" "runtime" @@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64 } return strconv.ParseFloat(durationStr, 64) } + +// BuildURL concatenates base and endpoint, returns the complete url string +func BuildURL(base string, endpoint string) string { + u, err := url.Parse(base) + if err != nil { + return base + endpoint + } + end := endpoint + if end == "" { + end = "/" + } + ref, err := url.Parse(end) + if err != nil { + return base + endpoint + } + return u.ResolveReference(ref).String() +} diff --git a/controller/ratio_config.go b/controller/ratio_config.go new file mode 100644 index 00000000..6ddc3d9e --- /dev/null +++ b/controller/ratio_config.go @@ -0,0 +1,24 @@ +package controller + +import ( + "net/http" + "one-api/setting/ratio_setting" + + "github.com/gin-gonic/gin" +) + +func GetRatioConfig(c *gin.Context) { + if !ratio_setting.IsExposeRatioEnabled() { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "倍率配置接口未启用", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": ratio_setting.GetExposedData(), + }) +} \ No newline at end of file diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go new file mode 100644 index 00000000..c7494b5b --- /dev/null +++ b/controller/ratio_sync.go @@ -0,0 +1,290 @@ +package controller + +import ( + "encoding/json" + "net/http" + "one-api/model" + "one-api/setting/ratio_setting" + "one-api/dto" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +type upstreamResult struct { + Name string `json:"name"` + Data map[string]any `json:"data,omitempty"` + Err string `json:"err,omitempty"` +} + +type TestResult struct { + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +type DifferenceItem struct { + Current interface{} `json:"current"` // 当前本地值,可能为null + Upstreams map[string]interface{} `json:"upstreams"` // 上游值:具体值/"same"/null +} + +// SyncableChannel 可同步的渠道信息 +type SyncableChannel struct { + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} + +// FetchUpstreamRatios 后端并发拉取上游倍率 +func FetchUpstreamRatios(c *gin.Context) { + var req dto.UpstreamRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } + + if req.Timeout <= 0 { + req.Timeout = 10 + } + + // build upstream list from ids + custom + var upstreams []dto.UpstreamDTO + if len(req.ChannelIDs) > 0 { + // convert []int64 -> []int for model function + intIds := make([]int, 0, len(req.ChannelIDs)) + for _, id64 := range req.ChannelIDs { + intIds = append(intIds, int(id64)) + } + dbChannels, _ := model.GetChannelsByIds(intIds) + for _, ch := range dbChannels { + upstreams = append(upstreams, dto.UpstreamDTO{ + Name: ch.Name, + BaseURL: ch.GetBaseURL(), + Endpoint: "", // assume default endpoint + }) + } + } + upstreams = append(upstreams, req.CustomChannels...) + + var wg sync.WaitGroup + ch := make(chan upstreamResult, len(upstreams)) + + for _, chn := range upstreams { + wg.Add(1) + go func(chItem dto.UpstreamDTO) { + defer wg.Done() + endpoint := chItem.Endpoint + if endpoint == "" { + endpoint = "/api/ratio_config" + } + url := chItem.BaseURL + endpoint + client := http.Client{Timeout: time.Duration(req.Timeout) * time.Second} + resp, err := client.Get(url) + if err != nil { + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + ch <- upstreamResult{Name: chItem.Name, Err: resp.Status} + return + } + var body struct { + Success bool `json:"success"` + Data map[string]any `json:"data"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + if !body.Success { + ch <- upstreamResult{Name: chItem.Name, Err: body.Message} + return + } + ch <- upstreamResult{Name: chItem.Name, Data: body.Data} + }(chn) + } + + wg.Wait() + close(ch) + + // 本地倍率配置 + localData := ratio_setting.GetExposedData() + + var testResults []dto.TestResult + var successfulChannels []struct { + name string + data map[string]any + } + + for r := range ch { + if r.Err != "" { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "error", + Error: r.Err, + }) + } else { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "success", + }) + successfulChannels = append(successfulChannels, struct { + name string + data map[string]any + }{name: r.Name, data: r.Data}) + } + } + + // 构建差异化数据 + differences := buildDifferences(localData, successfulChannels) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "differences": differences, + "test_results": testResults, + }, + }) +} + +// buildDifferences 构建差异化数据,只返回有意义的差异 +func buildDifferences(localData map[string]any, successfulChannels []struct { + name string + data map[string]any +}) map[string]map[string]dto.DifferenceItem { + differences := make(map[string]map[string]dto.DifferenceItem) + ratioTypes := []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} + + // 收集所有模型名称 + allModels := make(map[string]struct{}) + + // 从本地数据收集模型名称 + for _, ratioType := range ratioTypes { + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + for modelName := range localRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + // 从上游数据收集模型名称 + for _, channel := range successfulChannels { + for _, ratioType := range ratioTypes { + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + for modelName := range upstreamRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + // 对每个模型和每个比率类型进行分析 + for modelName := range allModels { + for _, ratioType := range ratioTypes { + // 获取本地值 + var localValue interface{} = nil + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + if val, exists := localRatio[modelName]; exists { + localValue = val + } + } + } + + // 收集上游值 + upstreamValues := make(map[string]interface{}) + hasUpstreamValue := false + hasDifference := false + + for _, channel := range successfulChannels { + var upstreamValue interface{} = nil + + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + if val, exists := upstreamRatio[modelName]; exists { + upstreamValue = val + hasUpstreamValue = true + + // 检查是否与本地值不同 + if localValue != nil && localValue != val { + hasDifference = true + } else if localValue == val { + upstreamValue = "same" + } + } + } + + // 如果本地值为空但上游有值,这也是差异 + if localValue == nil && upstreamValue != nil && upstreamValue != "same" { + hasDifference = true + } + + upstreamValues[channel.name] = upstreamValue + } + + // 应用过滤逻辑 + shouldInclude := false + + if localValue != nil { + // 规则1: 本地值存在,至少有一个上游与本地值不同 + if hasDifference { + shouldInclude = true + } + // 规则2: 本地值存在,但所有上游都未设置 - 不包含 + } else { + // 规则3: 本地值不存在,至少有一个上游设置了值 + if hasUpstreamValue { + shouldInclude = true + } + } + + if shouldInclude { + if differences[modelName] == nil { + differences[modelName] = make(map[string]dto.DifferenceItem) + } + differences[modelName][ratioType] = dto.DifferenceItem{ + Current: localValue, + Upstreams: upstreamValues, + } + } + } + } + + return differences +} + +// GetSyncableChannels 获取可用于倍率同步的渠道(base_url 不为空的渠道) +func GetSyncableChannels(c *gin.Context) { + channels, err := model.GetAllChannels(0, 0, true, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + var syncableChannels []dto.SyncableChannel + for _, channel := range channels { + // 只返回 base_url 不为空的渠道 + if channel.GetBaseURL() != "" { + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: channel.Id, + Name: channel.Name, + BaseURL: channel.GetBaseURL(), + Status: channel.Status, + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": syncableChannels, + }) +} \ No newline at end of file diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go new file mode 100644 index 00000000..4f2fe06d --- /dev/null +++ b/dto/ratio_sync.go @@ -0,0 +1,50 @@ +package dto + +// UpstreamDTO 提交到后端同步倍率的上游渠道信息 +// Endpoint 可以为空,后端会默认使用 /api/ratio_config +// BaseURL 必须以 http/https 开头,不要以 / 结尾 +// 例如: https://api.example.com +// Endpoint: /api/ratio_config +// 提交示例: +// { +// "name": "openai", +// "base_url": "https://api.openai.com", +// "endpoint": "/ratio_config" +// } + +type UpstreamDTO struct { + Name string `json:"name" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + Endpoint string `json:"endpoint"` +} + +type UpstreamRequest struct { + ChannelIDs []int64 `json:"channel_ids"` + CustomChannels []UpstreamDTO `json:"custom_channels"` + Timeout int `json:"timeout"` +} + +// TestResult 上游测试连通性结果 +type TestResult struct { + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +// DifferenceItem 差异项 +// Current 为本地值,可能为 nil +// Upstreams 为各渠道的上游值,具体数值 / "same" / nil + +type DifferenceItem struct { + Current interface{} `json:"current"` + Upstreams map[string]interface{} `json:"upstreams"` +} + +// SyncableChannel 可同步的渠道信息(base_url 不为空) + +type SyncableChannel struct { + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} \ No newline at end of file diff --git a/model/option.go b/model/option.go index 43c0a644..97f7baae 100644 --- a/model/option.go +++ b/model/option.go @@ -126,6 +126,7 @@ func InitOptionMap() { common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled()) // 自动添加所有注册的模型配置 modelConfigs := config.GlobalConfig.ExportAllConfigs() @@ -266,6 +267,8 @@ func updateOptionMap(key string, value string) (err error) { setting.WorkerAllowHttpImageRequestEnabled = boolValue case "DefaultUseAutoGroup": setting.DefaultUseAutoGroup = boolValue + case "ExposeRatioEnabled": + ratio_setting.SetExposeRatioEnabled(boolValue) } } switch key { diff --git a/router/api-router.go b/router/api-router.go index 45930246..badfa7bf 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind) apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin) apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind) + apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig) userRoute := apiRouter.Group("/user") { @@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) { optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio) optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } + ratioSyncRoute := apiRouter.Group("/ratio_sync") + ratioSyncRoute.Use(middleware.RootAuth()) + { + ratioSyncRoute.GET("/channels", controller.GetSyncableChannels) + ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios) + } channelRoute := apiRouter.Group("/channel") channelRoute.Use(middleware.AdminAuth()) { diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index aa934b22..51d473a8 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error { cacheRatioMapMutex.Lock() defer cacheRatioMapMutex.Unlock() cacheRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetCacheRatio returns the cache ratio for a model @@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) { } return ratio, true } + +func GetCacheRatioCopy() map[string]float64 { + cacheRatioMapMutex.RLock() + defer cacheRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(cacheRatioMap)) + for k, v := range cacheRatioMap { + copyMap[k] = v + } + return copyMap +} diff --git a/setting/ratio_setting/expose_ratio.go b/setting/ratio_setting/expose_ratio.go new file mode 100644 index 00000000..8fca0bcb --- /dev/null +++ b/setting/ratio_setting/expose_ratio.go @@ -0,0 +1,17 @@ +package ratio_setting + +import "sync/atomic" + +var exposeRatioEnabled atomic.Bool + +func init() { + exposeRatioEnabled.Store(false) +} + +func SetExposeRatioEnabled(enabled bool) { + exposeRatioEnabled.Store(enabled) +} + +func IsExposeRatioEnabled() bool { + return exposeRatioEnabled.Load() +} \ No newline at end of file diff --git a/setting/ratio_setting/exposed_cache.go b/setting/ratio_setting/exposed_cache.go new file mode 100644 index 00000000..9e5b6c30 --- /dev/null +++ b/setting/ratio_setting/exposed_cache.go @@ -0,0 +1,55 @@ +package ratio_setting + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/gin-gonic/gin" +) + +const exposedDataTTL = 30 * time.Second + +type exposedCache struct { + data gin.H + expiresAt time.Time +} + +var ( + exposedData atomic.Value + rebuildMu sync.Mutex +) + +func InvalidateExposedDataCache() { + exposedData.Store((*exposedCache)(nil)) +} + +func cloneGinH(src gin.H) gin.H { + dst := make(gin.H, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func GetExposedData() gin.H { + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + rebuildMu.Lock() + defer rebuildMu.Unlock() + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + newData := gin.H{ + "model_ratio": GetModelRatioCopy(), + "completion_ratio": GetCompletionRatioCopy(), + "cache_ratio": GetCacheRatioCopy(), + "model_price": GetModelPriceCopy(), + } + exposedData.Store(&exposedCache{ + data: newData, + expiresAt: time.Now().Add(exposedDataTTL), + }) + return cloneGinH(newData) +} \ No newline at end of file diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 3102dfe9..1eaf25b1 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -317,7 +317,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error { modelPriceMapMutex.Lock() defer modelPriceMapMutex.Unlock() modelPriceMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelPriceMap) + err := json.Unmarshal([]byte(jsonStr), &modelPriceMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false @@ -345,7 +349,11 @@ func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() modelRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelRatioMap) + err := json.Unmarshal([]byte(jsonStr), &modelRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // 处理带有思考预算的模型名称,方便统一定价 @@ -405,7 +413,11 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { CompletionRatioMutex.Lock() defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &CompletionRatio) + err := json.Unmarshal([]byte(jsonStr), &CompletionRatio) + if err == nil { + InvalidateExposedDataCache() + } + return err } func GetCompletionRatio(name string) float64 { @@ -609,3 +621,33 @@ func GetImageRatio(name string) (float64, bool) { } return ratio, true } + +func GetModelRatioCopy() map[string]float64 { + modelRatioMapMutex.RLock() + defer modelRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelRatioMap)) + for k, v := range modelRatioMap { + copyMap[k] = v + } + return copyMap +} + +func GetModelPriceCopy() map[string]float64 { + modelPriceMapMutex.RLock() + defer modelPriceMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelPriceMap)) + for k, v := range modelPriceMap { + copyMap[k] = v + } + return copyMap +} + +func GetCompletionRatioCopy() map[string]float64 { + CompletionRatioMutex.RLock() + defer CompletionRatioMutex.RUnlock() + copyMap := make(map[string]float64, len(CompletionRatio)) + for k, v := range CompletionRatio { + copyMap[k] = v + } + return copyMap +} diff --git a/web/src/components/settings/ChannelSelectorModal.js b/web/src/components/settings/ChannelSelectorModal.js new file mode 100644 index 00000000..c393d97f --- /dev/null +++ b/web/src/components/settings/ChannelSelectorModal.js @@ -0,0 +1,154 @@ +import React from 'react'; +import { + Modal, + Transfer, + Input, + Card, + Space, + Button, + Checkbox, +} from '@douyinfe/semi-ui'; +import { IconPlus, IconClose } from '@douyinfe/semi-icons'; + +/** + * ChannelSelectorModal + * 负责选择同步渠道、测试与批量测试等 UI,纯展示组件。 + * 业务状态与动作通过 props 注入,保持可复用与可测试。 + */ +export default function ChannelSelectorModal({ + t, + visible, + onCancel, + onOk, + // 渠道与选择 + allChannels = [], + selectedChannelIds = [], + setSelectedChannelIds, + // 自定义渠道 + customUrl, + setCustomUrl, + customEndpoint, + setCustomEndpoint, + customChannelTesting, + addCustomChannel, + // 渠道端点 + channelEndpoints, + updateChannelEndpoint, + // 测试相关 +}) { + // Transfer 自定义渲染 + const renderSourceItem = (item) => { + const channelId = item.key || item.value; + const currentEndpoint = channelEndpoints[channelId]; + const baseUrl = item._originalData?.base_url || ''; + + return ( +
+
+
+ + {item.label} + +
+
+ + {baseUrl} + + updateChannelEndpoint(channelId, value)} + placeholder="/api/ratio_config" + className="flex-1 text-xs" + style={{ fontSize: '12px' }} + /> +
+
+
+ ); + }; + + const renderSelectedItem = (item) => { + const channelId = item.key || item.value; + const currentEndpoint = channelEndpoints[channelId]; + const baseUrl = item._originalData?.base_url || ''; + + return ( +
+
+
+ {item.label} + +
+
+ + {baseUrl} + + + {currentEndpoint} + +
+
+
+ ); + }; + + const channelFilter = (input, item) => item.label.toLowerCase().includes(input.toLowerCase()); + + return ( + {t('选择同步渠道')}} + width={1000} + > + + + + + + + + + + + + + ); +} \ No newline at end of file diff --git a/web/src/components/settings/RatioSetting.js b/web/src/components/settings/RatioSetting.js index bf97282c..1d87c6de 100644 --- a/web/src/components/settings/RatioSetting.js +++ b/web/src/components/settings/RatioSetting.js @@ -6,6 +6,7 @@ import GroupRatioSettings from '../../pages/Setting/Ratio/GroupRatioSettings.js' import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js'; import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js'; import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js'; +import UpstreamRatioSync from '../../pages/Setting/Ratio/UpstreamRatioSync.js'; import { API, showError } from '../../helpers'; @@ -21,6 +22,7 @@ const RatioSetting = () => { GroupGroupRatio: '', AutoGroups: '', DefaultUseAutoGroup: false, + ExposeRatioEnabled: false, UserUsableGroups: '', }); @@ -48,7 +50,7 @@ const RatioSetting = () => { // 如果后端返回的不是合法 JSON,直接展示 } } - if (['DefaultUseAutoGroup'].includes(item.key)) { + if (['DefaultUseAutoGroup', 'ExposeRatioEnabled'].includes(item.key)) { newInputs[item.key] = item.value === 'true' ? true : false; } else { newInputs[item.key] = item.value; @@ -78,10 +80,6 @@ const RatioSetting = () => { return ( - {/* 分组倍率设置 */} - - - {/* 模型倍率设置以及可视化编辑器 */} @@ -100,8 +98,18 @@ const RatioSetting = () => { refresh={onRefresh} /> + + + + {/* 分组倍率设置 */} + + + ); }; diff --git a/web/src/helpers/ratio.js b/web/src/helpers/ratio.js new file mode 100644 index 00000000..fb293c80 --- /dev/null +++ b/web/src/helpers/ratio.js @@ -0,0 +1,20 @@ +export const DEFAULT_ENDPOINT = '/api/ratio_config'; + +/** + * buildEndpointUrl: 拼接 baseUrl 与 endpoint,确保不会出现双斜杠或缺失斜杠问题。 + * 使用 URL 构造函数保证协议/域名安全;若 baseUrl 非标准 URL,则退回字符串拼接。 + * @param {string} baseUrl - 基础地址,例如 https://api.example.com + * @param {string} endpoint - 接口路径,例如 /api/ratio_config + * @returns {string} + */ +export const buildEndpointUrl = (baseUrl, endpoint) => { + if (!baseUrl) return endpoint; + try { + return new URL(endpoint, baseUrl).toString(); + } catch (_) { + // fallback 处理不规范的 baseUrl + const cleanedBase = baseUrl.endsWith('/') ? baseUrl.slice(0, -1) : baseUrl; + const cleanedEndpoint = endpoint.startsWith('/') ? endpoint.slice(1) : endpoint; + return `${cleanedBase}/${cleanedEndpoint}`; + } +}; \ No newline at end of file diff --git a/web/src/pages/Detail/index.js b/web/src/pages/Detail/index.js index 15c02abf..0fd18d16 100644 --- a/web/src/pages/Detail/index.js +++ b/web/src/pages/Detail/index.js @@ -1112,7 +1112,6 @@ const Detail = (props) => { @@ -1389,7 +1388,6 @@ const Detail = (props) => { ) : ( + + + + setInputs({ ...inputs, ExposeRatioEnabled: value }) + } + /> + + diff --git a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js new file mode 100644 index 00000000..518c6468 --- /dev/null +++ b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js @@ -0,0 +1,596 @@ +import React, { useState, useCallback, useMemo } from 'react'; +import { + Button, + Table, + Tag, + Empty, + Checkbox, + Form, +} from '@douyinfe/semi-ui'; +import { + RefreshCcw, + CheckSquare, +} from 'lucide-react'; +import { + DEFAULT_ENDPOINT, + buildEndpointUrl, +} from '../../../helpers/ratio'; +import { API, showError, showSuccess, showWarning } from '../../../helpers'; +import { useTranslation } from 'react-i18next'; +import { + IllustrationNoResult, + IllustrationNoResultDark +} from '@douyinfe/semi-illustrations'; +import ChannelSelectorModal from '../../../components/settings/ChannelSelectorModal'; + +export default function UpstreamRatioSync(props) { + const { t } = useTranslation(); + const [modalVisible, setModalVisible] = useState(false); + const [loading, setLoading] = useState(false); + const [syncLoading, setSyncLoading] = useState(false); + + // 渠道选择相关 + const [allChannels, setAllChannels] = useState([]); + const [selectedChannelIds, setSelectedChannelIds] = useState([]); + + // 自定义渠道 + const [customUrl, setCustomUrl] = useState(''); + const [customEndpoint, setCustomEndpoint] = useState(DEFAULT_ENDPOINT); + const [customChannelTesting, setCustomChannelTesting] = useState(false); + + // 渠道端点配置 + const [channelEndpoints, setChannelEndpoints] = useState({}); // { channelId: endpoint } + + // 差异数据和测试结果 + const [differences, setDifferences] = useState({}); + const [testResults, setTestResults] = useState([]); + const [resolutions, setResolutions] = useState({}); + + // 分页相关状态 + const [currentPage, setCurrentPage] = useState(1); + const [pageSize, setPageSize] = useState(10); + + // 当前倍率快照 + const currentRatiosSnapshot = useMemo(() => ({ + model_ratio: JSON.parse(props.options.ModelRatio || '{}'), + completion_ratio: JSON.parse(props.options.CompletionRatio || '{}'), + cache_ratio: JSON.parse(props.options.CacheRatio || '{}'), + model_price: JSON.parse(props.options.ModelPrice || '{}'), + }), [props.options]); + + // 获取所有渠道 + const fetchAllChannels = async () => { + setLoading(true); + try { + const res = await API.get('/api/ratio_sync/channels'); + + if (res.data.success) { + const channels = res.data.data || []; + + // 转换为Transfer组件所需格式 + const transferData = channels.map(channel => ({ + key: channel.id, + label: channel.name, + value: channel.id, + disabled: false, // 所有渠道都可以选择 + _originalData: channel, + })); + + setAllChannels(transferData); + + // 初始化端点配置 + const initialEndpoints = {}; + transferData.forEach(channel => { + initialEndpoints[channel.key] = DEFAULT_ENDPOINT; + }); + setChannelEndpoints(initialEndpoints); + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('获取渠道失败:') + error.message); + } finally { + setLoading(false); + } + }; + + // 测试自定义渠道 + const testCustomChannel = async () => { + if (!customUrl) { + showWarning(t('请输入渠道地址')); + return false; + } + + setCustomChannelTesting(true); + + try { + const url = buildEndpointUrl(customUrl, customEndpoint); + const client = { timeout: 10000 }; + + const response = await fetch(url, { + method: 'GET', + signal: AbortSignal.timeout(client.timeout) + }); + + if (response.ok) { + const data = await response.json(); + if (data.success) { + return true; + } else { + showError(t('测试失败') + `: ${data.message || t('响应格式错误')}`); + return false; + } + } else { + showError(t('测试失败') + `: HTTP ${response.status}`); + return false; + } + } catch (error) { + showError(t('测试失败') + `: ${error.message || t('请求超时')}`); + return false; + } finally { + setCustomChannelTesting(false); + } + }; + + // 添加自定义渠道 + const addCustomChannel = async () => { + if (!customUrl) { + showWarning(t('请输入渠道地址')); + return; + } + + // 先测试渠道 + const testResult = await testCustomChannel(); + if (!testResult) { + return; + } + + let hostname; + try { + hostname = new URL(customUrl).hostname; + } catch (e) { + hostname = customUrl; + } + + const customId = `custom_${Date.now()}`; + const newChannel = { + key: customId, + label: hostname, + value: customId, + disabled: false, + _originalData: { + id: customId, + name: hostname, + base_url: customUrl.endsWith('/') ? customUrl.slice(0, -1) : customUrl, + status: 1, + is_custom: true, + }, + }; + + setAllChannels([...allChannels, newChannel]); + setSelectedChannelIds([...selectedChannelIds, customId]); + setChannelEndpoints(prev => ({ ...prev, [customId]: customEndpoint })); + setCustomUrl(''); + showSuccess(t('测试成功,渠道添加成功')); + }; + + // 确认选择渠道 + const confirmChannelSelection = () => { + const selected = allChannels + .filter(ch => selectedChannelIds.includes(ch.value)) + .map(ch => ch._originalData); + + if (selected.length === 0) { + showWarning(t('请至少选择一个渠道')); + return; + } + + setModalVisible(false); + fetchRatiosFromChannels(selected); + }; + + // 从选定渠道获取倍率 + const fetchRatiosFromChannels = async (channelList) => { + setSyncLoading(true); + + // 分离数据库渠道和自定义渠道 + const dbChannels = channelList.filter(ch => !ch.is_custom); + const customChannels = channelList.filter(ch => ch.is_custom); + + const payload = { + channel_ids: dbChannels.map(ch => parseInt(ch.id)), + custom_channels: customChannels.map(ch => ({ + name: ch.name, + base_url: ch.base_url, + endpoint: channelEndpoints[ch.id] || DEFAULT_ENDPOINT, + })), + timeout: 10 + }; + + try { + const res = await API.post('/api/ratio_sync/fetch', payload); + + if (!res.data.success) { + showError(res.data.message || t('后端请求失败')); + setSyncLoading(false); + return; + } + + const { differences = {}, test_results = [] } = res.data.data; + + // 显示测试结果 + const errorResults = test_results.filter(r => r.status === 'error'); + if (errorResults.length > 0) { + showWarning(t('部分渠道测试失败:') + errorResults.map(r => `${r.name}: ${r.error}`).join(', ')); + } + + setDifferences(differences); + setTestResults(test_results); + setResolutions({}); + + // 判断是否有差异 + if (Object.keys(differences).length === 0) { + showSuccess(t('已与上游倍率完全一致,无需同步')); + } + } catch (e) { + showError(t('请求后端接口失败:') + e.message); + } finally { + setSyncLoading(false); + } + }; + + // 解决冲突/选择值 + const selectValue = (model, ratioType, value) => { + setResolutions(prev => ({ + ...prev, + [model]: { + ...prev[model], + [ratioType]: value, + }, + })); + }; + + // 应用同步 + const applySync = async () => { + const currentRatios = { + ModelRatio: JSON.parse(props.options.ModelRatio || '{}'), + CompletionRatio: JSON.parse(props.options.CompletionRatio || '{}'), + CacheRatio: JSON.parse(props.options.CacheRatio || '{}'), + ModelPrice: JSON.parse(props.options.ModelPrice || '{}'), + }; + + // 应用已选择的值 + Object.entries(resolutions).forEach(([model, ratios]) => { + Object.entries(ratios).forEach(([ratioType, value]) => { + const optionKey = ratioType + .split('_') + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(''); + currentRatios[optionKey][model] = parseFloat(value); + }); + }); + + // 保存到后端 + setLoading(true); + try { + const updates = Object.entries(currentRatios).map(([key, value]) => + API.put('/api/option/', { + key, + value: JSON.stringify(value, null, 2), + }) + ); + + const results = await Promise.all(updates); + + if (results.every(res => res.data.success)) { + showSuccess(t('同步成功')); + props.refresh(); + // 清空状态 + setDifferences({}); + setTestResults([]); + setResolutions({}); + setSelectedChannelIds([]); + } else { + showError(t('部分保存失败')); + } + } catch (error) { + showError(t('保存失败')); + } finally { + setLoading(false); + } + }; + + // 计算当前页显示的数据 + const getCurrentPageData = (dataSource) => { + const startIndex = (currentPage - 1) * pageSize; + const endIndex = startIndex + pageSize; + return dataSource.slice(startIndex, endIndex); + }; + + // 渲染表格头部 + const renderHeader = () => ( +
+
+
+ + + {(() => { + // 检查是否有选择可应用的值 + const hasSelections = Object.keys(resolutions).length > 0; + + return ( + + ); + })()} +
+
+
+ ); + + // 渲染差异表格 + const renderDifferenceTable = () => { + // 构建数据源 + const dataSource = useMemo(() => { + const tmp = []; + + Object.entries(differences).forEach(([model, ratioTypes]) => { + Object.entries(ratioTypes).forEach(([ratioType, diff]) => { + tmp.push({ + key: `${model}_${ratioType}`, + model, + ratioType, + current: diff.current, + upstreams: diff.upstreams, + }); + }); + }); + + return tmp; + }, [differences]); + + // 收集所有上游渠道名称 + const upstreamNames = useMemo(() => { + const set = new Set(); + dataSource.forEach((row) => { + Object.keys(row.upstreams || {}).forEach((name) => set.add(name)); + }); + return Array.from(set); + }, [dataSource]); + + if (dataSource.length === 0) { + return ( + } + darkModeImage={} + description={Object.keys(differences).length === 0 ? t('已与上游倍率完全一致') : t('请先选择同步渠道')} + style={{ padding: 30 }} + /> + ); + } + + // 列定义 + const columns = [ + { + title: t('模型'), + dataIndex: 'model', + fixed: 'left', + width: 160, + }, + { + title: t('倍率类型'), + dataIndex: 'ratioType', + width: 140, + render: (text) => { + const typeMap = { + model_ratio: t('模型倍率'), + completion_ratio: t('补全倍率'), + cache_ratio: t('缓存倍率'), + model_price: t('固定价格'), + }; + return {typeMap[text] || text}; + }, + }, + { + title: t('当前值'), + dataIndex: 'current', + width: 100, + render: (text) => ( + + {text !== null && text !== undefined ? text : t('未设置')} + + ), + }, + // 动态上游列 + ...upstreamNames.map((upName) => { + // 计算该渠道的全选状态 + const channelStats = (() => { + let selectableCount = 0; // 可选择的项目数量 + let selectedCount = 0; // 已选择的项目数量 + + dataSource.forEach((row) => { + const upstreamVal = row.upstreams?.[upName]; + // 只有具体数值的才是可选择的(不是null、undefined或"same") + if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { + selectableCount++; + const isSelected = resolutions[row.model]?.[row.ratioType] === upstreamVal; + if (isSelected) { + selectedCount++; + } + } + }); + + return { + selectableCount, + selectedCount, + allSelected: selectableCount > 0 && selectedCount === selectableCount, + partiallySelected: selectedCount > 0 && selectedCount < selectableCount, + hasSelectableItems: selectableCount > 0 + }; + })(); + + // 处理全选/取消全选 + const handleBulkSelect = (checked) => { + setResolutions((prev) => { + const newRes = { ...prev }; + + dataSource.forEach((row) => { + const upstreamVal = row.upstreams?.[upName]; + if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { + if (checked) { + // 选择该值 + if (!newRes[row.model]) newRes[row.model] = {}; + newRes[row.model][row.ratioType] = upstreamVal; + } else { + // 取消选择该值 + if (newRes[row.model]) { + delete newRes[row.model][row.ratioType]; + if (Object.keys(newRes[row.model]).length === 0) { + delete newRes[row.model]; + } + } + } + } + }); + + return newRes; + }); + }; + + return { + title: channelStats.hasSelectableItems ? ( + handleBulkSelect(e.target.checked)} + > + {upName} + + ) : ( + {upName} + ), + dataIndex: upName, + width: 140, + render: (_, record) => { + const upstreamVal = record.upstreams?.[upName]; + + if (upstreamVal === null || upstreamVal === undefined) { + return {t('未设置')}; + } + + if (upstreamVal === 'same') { + return {t('与本地相同')}; + } + + // 有具体值,可以选择 + const isSelected = resolutions[record.model]?.[record.ratioType] === upstreamVal; + + return ( + { + const isChecked = e.target.checked; + if (isChecked) { + selectValue(record.model, record.ratioType, upstreamVal); + } else { + setResolutions((prev) => { + const newRes = { ...prev }; + if (newRes[record.model]) { + delete newRes[record.model][record.ratioType]; + if (Object.keys(newRes[record.model]).length === 0) { + delete newRes[record.model]; + } + } + return newRes; + }); + } + }} + > + {upstreamVal} + + ); + }, + }; + }), + ]; + + return ( + t('第 {{start}} - {{end}} 条,共 {{total}} 条', { + start: page.currentStart, + end: page.currentEnd, + total: dataSource.length, + }), + pageSizeOptions: ['5', '10', '20', '50'], + onChange: (page, size) => { + setCurrentPage(page); + setPageSize(size); + }, + onShowSizeChange: (current, size) => { + setCurrentPage(1); + setPageSize(size); + } + }} + scroll={{ x: 'max-content' }} + size='middle' + loading={loading || syncLoading} + className="rounded-xl overflow-hidden" + /> + ); + }; + + // 更新渠道端点 + const updateChannelEndpoint = useCallback((channelId, endpoint) => { + setChannelEndpoints(prev => ({ ...prev, [channelId]: endpoint })); + }, []); + + return ( + <> + + {renderDifferenceTable()} + + + setModalVisible(false)} + onOk={confirmChannelSelection} + allChannels={allChannels} + selectedChannelIds={selectedChannelIds} + setSelectedChannelIds={setSelectedChannelIds} + customUrl={customUrl} + setCustomUrl={setCustomUrl} + customEndpoint={customEndpoint} + setCustomEndpoint={setCustomEndpoint} + customChannelTesting={customChannelTesting} + addCustomChannel={addCustomChannel} + channelEndpoints={channelEndpoints} + updateChannelEndpoint={updateChannelEndpoint} + /> + + ); +} \ No newline at end of file From 8a79de333a580243885af38b5eebeb329ea32192 Mon Sep 17 00:00:00 2001 From: skynono Date: Sun, 8 Jun 2025 21:40:57 +0800 Subject: [PATCH 02/19] feat: add video channel kling --- common/constants.go | 2 + constant/task.go | 1 + controller/relay.go | 2 +- controller/task.go | 2 + controller/task_video.go | 142 ++++++++++ dto/video.go | 47 ++++ middleware/distributor.go | 9 + relay/channel/adapter.go | 2 + relay/channel/task/kling/adaptor.go | 312 ++++++++++++++++++++++ relay/channel/task/suno/adaptor.go | 4 + relay/constant/relay_mode.go | 13 + relay/relay_adaptor.go | 3 + relay/relay_task.go | 32 ++- router/main.go | 1 + router/video-router.go | 17 ++ web/src/components/table/TaskLogsTable.js | 31 ++- web/src/constants/channel.constants.js | 5 + 17 files changed, 619 insertions(+), 6 deletions(-) create mode 100644 controller/task_video.go create mode 100644 dto/video.go create mode 100644 relay/channel/task/kling/adaptor.go create mode 100644 router/video-router.go diff --git a/common/constants.go b/common/constants.go index bee00506..ac803148 100644 --- a/common/constants.go +++ b/common/constants.go @@ -241,6 +241,7 @@ const ( ChannelTypeXinference = 47 ChannelTypeXai = 48 ChannelTypeCoze = 49 + ChannelTypeKling = 50 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{ "", //47 "https://api.x.ai", //48 "https://api.coze.cn", //49 + "https://api.klingai.com", //50 } diff --git a/constant/task.go b/constant/task.go index 1a68b812..d466fc8a 100644 --- a/constant/task.go +++ b/constant/task.go @@ -5,6 +5,7 @@ type TaskPlatform string const ( TaskPlatformSuno TaskPlatform = "suno" TaskPlatformMidjourney = "mj" + TaskPlatformKling TaskPlatform = "kling" ) const ( diff --git a/controller/relay.go b/controller/relay.go index c1c45114..4da4262b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) { func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { var err *dto.TaskError switch relayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID: + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID: err = relay.RelayTaskFetch(c, relayMode) default: err = relay.RelayTaskSubmit(c, relayMode) diff --git a/controller/task.go b/controller/task.go index 34e14f3f..f7523e87 100644 --- a/controller/task.go +++ b/controller/task.go @@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) case constant.TaskPlatformSuno: _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) + case constant.TaskPlatformKling: + _ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM) default: common.SysLog("未知平台") } diff --git a/controller/task_video.go b/controller/task_video.go new file mode 100644 index 00000000..3f2c9588 --- /dev/null +++ b/controller/task_video.go @@ -0,0 +1,142 @@ +package controller + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/constant" + "one-api/model" + "one-api/relay" + "one-api/relay/channel" +) + +func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil { + common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + } + } + return nil +} + +func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { + common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + cacheGetChannel, err := model.CacheGetChannel(channelId) + if err != nil { + errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ + "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if errUpdate != nil { + common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + } + return fmt.Errorf("CacheGetChannel failed: %w", err) + } + adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling) + if adaptor == nil { + return fmt.Errorf("video adaptor not found") + } + for _, taskId := range taskIds { + if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { + common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + } + } + return nil +} + +func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { + baseURL := common.ChannelBaseURLs[channel.Type] + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() + } + resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ + "task_id": taskId, + }) + if err != nil { + return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err) + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err) + } + + var responseItem map[string]interface{} + err = json.Unmarshal(responseBody, &responseItem) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody))) + return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err) + } + + code, _ := responseItem["code"].(float64) + if code != 0 { + return fmt.Errorf("video task fetch failed for task %s", taskId) + } + + data, ok := responseItem["data"].(map[string]interface{}) + if !ok { + common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody))) + return fmt.Errorf("video task data format error for task %s", taskId) + } + + task := taskM[taskId] + if task == nil { + common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + return fmt.Errorf("task %s not found", taskId) + } + + if status, ok := data["task_status"].(string); ok { + switch status { + case "submitted", "queued": + task.Status = model.TaskStatusSubmitted + case "processing": + task.Status = model.TaskStatusInProgress + case "succeed": + task.Status = model.TaskStatusSuccess + task.Progress = "100%" + if url, err := adaptor.(interface { + ParseResultUrl(map[string]any) (string, error) + }).ParseResultUrl(responseItem); err == nil { + task.FailReason = url + } else { + common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error())) + } + case "failed": + task.Status = model.TaskStatusFailure + task.Progress = "100%" + if reason, ok := data["fail_reason"].(string); ok { + task.FailReason = reason + } + } + } + + // If task failed, refund quota + if task.Status == model.TaskStatusFailure { + common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + quota := task.Quota + if quota != 0 { + if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { + common.LogError(ctx, "Failed to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + } + + task.Data = responseBody + if err := task.Update(); err != nil { + common.SysError("UpdateVideoTask task error: " + err.Error()) + } + + return nil +} diff --git a/dto/video.go b/dto/video.go new file mode 100644 index 00000000..5b48146a --- /dev/null +++ b/dto/video.go @@ -0,0 +1,47 @@ +package dto + +type VideoRequest struct { + Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID + Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt + Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64) + Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds) + Width int `json:"width" example:"512"` // Video width + Height int `json:"height" example:"512"` // Video height + Fps int `json:"fps,omitempty" example:"30"` // Video frame rate + Seed int `json:"seed,omitempty" example:"20231234"` // Random seed + N int `json:"n,omitempty" example:"1"` // Number of videos to generate + ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format + User string `json:"user,omitempty" example:"user-1234"` // User identifier + Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.) +} + +// VideoResponse 视频生成提交任务后的响应 +type VideoResponse struct { + TaskId string `json:"task_id"` + Status string `json:"status"` +} + +// VideoTaskResponse 查询视频生成任务状态的响应 +type VideoTaskResponse struct { + TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID + Status string `json:"status" example:"succeeded"` // 任务状态 + Url string `json:"url,omitempty"` // 视频资源URL(成功时) + Format string `json:"format,omitempty" example:"mp4"` // 视频格式 + Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据 + Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时) +} + +// VideoTaskMetadata 视频任务元数据 +type VideoTaskMetadata struct { + Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长 + Fps int `json:"fps" example:"30"` // 实际帧率 + Width int `json:"width" example:"512"` // 实际宽度 + Height int `json:"height" example:"512"` // 实际高度 + Seed int `json:"seed" example:"20231234"` // 使用的随机种子 +} + +// VideoTaskError 视频任务错误信息 +type VideoTaskError struct { + Code int `json:"code"` + Message string `json:"message"` +} diff --git a/middleware/distributor.go b/middleware/distributor.go index 84eb182e..9d074ce8 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -170,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) + } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { + relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path) + if relayMode == relayconstant.RelayModeKlingFetchByID { + shouldSelectChannel = false + } else { + err = common.UnmarshalBodyReusable(c, &modelRequest) + } + c.Set("platform", string(constant.TaskPlatformKling)) + c.Set("relay_mode", relayMode) } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent relayMode := relayconstant.RelayModeGemini diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 50255d0a..873997f6 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -44,4 +44,6 @@ type TaskAdaptor interface { // FetchTask FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) + + ParseResultUrl(resp map[string]any) (string, error) } diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go new file mode 100644 index 00000000..9c6773f5 --- /dev/null +++ b/relay/channel/task/kling/adaptor.go @@ -0,0 +1,312 @@ +package kling + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" + + "one-api/common" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" +) + +// ============================ +// Request / Response structures +// ============================ + +type SubmitReq struct { + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Mode string `json:"mode,omitempty"` + Image string `json:"image,omitempty"` + Size string `json:"size,omitempty"` + Duration int `json:"duration,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +type requestPayload struct { + Prompt string `json:"prompt,omitempty"` + Image string `json:"image,omitempty"` + Mode string `json:"mode,omitempty"` + Duration string `json:"duration,omitempty"` + AspectRatio string `json:"aspect_ratio,omitempty"` + Model string `json:"model,omitempty"` + ModelName string `json:"model_name,omitempty"` + CfgScale float64 `json:"cfg_scale,omitempty"` +} + +type responsePayload struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + TaskID string `json:"task_id"` + } `json:"data"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + accessKey string + secretKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.BaseUrl + + // apiKey format: "access_key,secret_key" + keyParts := strings.Split(info.ApiKey, ",") + if len(keyParts) == 2 { + a.accessKey = strings.TrimSpace(keyParts[0]) + a.secretKey = strings.TrimSpace(keyParts[1]) + } +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + // Accept only POST /v1/video/generations as "generate" action. + action := "generate" + info.Action = action + + var req SubmitReq + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + if strings.TrimSpace(req.Prompt) == "" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) + return + } + + // Store into context for later usage + c.Set("kling_request", req) + return nil +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil +} + +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + token, err := a.createJWTToken() + if err != nil { + token = info.ApiKey // fallback + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", "kling-sdk/1.0") + return nil +} + +// BuildRequestBody converts request into Kling specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { + v, exists := c.Get("kling_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req := v.(SubmitReq) + + body := a.convertToRequestPayload(&req) + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + + // Attempt Kling response parse first. + var kResp responsePayload + if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 { + c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID}) + return kResp.Data.TaskID, responseBody, nil + } + + // Fallback generic task response. + var generic dto.TaskResponse[string] + if err := json.Unmarshal(responseBody, &generic); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + if !generic.IsSuccess() { + taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{"task_id": generic.Data}) + return generic.Data, responseBody, nil +} + +// FetchTask fetch task status +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + token, err := a.createJWTTokenWithKey(key) + if err != nil { + token = key + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + req = req.WithContext(ctx) + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", "kling-sdk/1.0") + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return []string{"kling-v1", "kling-v1-6", "kling-v2-master"} +} + +func (a *TaskAdaptor) GetChannelName() string { + return "kling" +} + +// ============================ +// helpers +// ============================ + +func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload { + r := &requestPayload{ + Prompt: req.Prompt, + Image: req.Image, + Mode: defaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + AspectRatio: a.getAspectRatio(req.Size), + Model: req.Model, + ModelName: req.Model, + CfgScale: 0.5, + } + if r.Model == "" { + r.Model = "kling-v1" + r.ModelName = "kling-v1" + } + return r +} + +func (a *TaskAdaptor) getAspectRatio(size string) string { + switch size { + case "1024x1024", "512x512": + return "1:1" + case "1280x720", "1920x1080": + return "16:9" + case "720x1280", "1080x1920": + return "9:16" + default: + return "1:1" + } +} + +func defaultString(s, def string) string { + if strings.TrimSpace(s) == "" { + return def + } + return s +} + +func defaultInt(v int, def int) int { + if v == 0 { + return def + } + return v +} + +// ============================ +// JWT helpers +// ============================ + +func (a *TaskAdaptor) createJWTToken() (string, error) { + return a.createJWTTokenWithKeys(a.accessKey, a.secretKey) +} + +func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { + parts := strings.Split(apiKey, ",") + if len(parts) != 2 { + return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'") + } + return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])) +} + +func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) { + if accessKey == "" || secretKey == "" { + return "", fmt.Errorf("access key and secret key are required") + } + now := time.Now().Unix() + claims := jwt.MapClaims{ + "iss": accessKey, + "exp": now + 1800, // 30 minutes + "nbf": now - 5, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["typ"] = "JWT" + return token.SignedString([]byte(secretKey)) +} + +// ParseResultUrl 提取视频任务结果的 url +func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) { + data, ok := resp["data"].(map[string]any) + if !ok { + return "", fmt.Errorf("data field not found or invalid") + } + taskResult, ok := data["task_result"].(map[string]any) + if !ok { + return "", fmt.Errorf("task_result field not found or invalid") + } + videos, ok := taskResult["videos"].([]interface{}) + if !ok || len(videos) == 0 { + return "", fmt.Errorf("videos field not found or empty") + } + video, ok := videos[0].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("video item invalid") + } + url, ok := video["url"].(string) + if !ok || url == "" { + return "", fmt.Errorf("url field not found or invalid") + } + return url, nil +} diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 03d60516..f7042348 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -22,6 +22,10 @@ type TaskAdaptor struct { ChannelType int } +func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) { + return "", nil // todo implement this method if needed +} + func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType } diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index f22a20bd..02a286e2 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -38,6 +38,9 @@ const ( RelayModeSunoFetchByID RelayModeSunoSubmit + RelayModeKlingFetchByID + RelayModeKlingSubmit + RelayModeRerank RelayModeResponses @@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int { } return relayMode } + +func Path2RelayKling(method, path string) int { + relayMode := RelayModeUnknown + if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") { + relayMode = RelayModeKlingSubmit + } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") { + relayMode = RelayModeKlingFetchByID + } + return relayMode +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 7bf0da9f..626bb7e4 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -22,6 +22,7 @@ import ( "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" + "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" "one-api/relay/channel/vertex" @@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor { // return &aiproxy.Adaptor{} case commonconstant.TaskPlatformSuno: return &suno.TaskAdaptor{} + case commonconstant.TaskPlatformKling: + return &kling.TaskAdaptor{} } return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index 3da9a20f..245fd681 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -37,6 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) + if platform == constant.TaskPlatformKling { + modelName = relayInfo.OriginModelName + } modelPrice, success := ratio_setting.GetModelPrice(modelName, true) if !success { defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName] @@ -136,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } relayInfo.ConsumeQuota = true // insert task - task := model.InitTask(constant.TaskPlatformSuno, relayInfo) + task := model.InitTask(platform, relayInfo) task.TaskID = taskID task.Quota = quota task.Data = taskData + task.Action = relayInfo.Action err = task.Insert() if err != nil { taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) @@ -149,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ - relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, - relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, + relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, + relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, + relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder, } func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { @@ -225,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt return } +func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { + taskId := c.Param("id") + userId := c.GetInt("id") + + originTask, exist, err := model.GetByTaskId(userId, taskId) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) + return + } + if !exist { + taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) + return + } + + respBody, err = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: TaskModel2Dto(originTask), + }) + return +} + func TaskModel2Dto(task *model.Task) *dto.TaskDto { return &dto.TaskDto{ TaskID: task.TaskID, diff --git a/router/main.go b/router/main.go index b8ac4055..0d2bfdce 100644 --- a/router/main.go +++ b/router/main.go @@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { SetApiRouter(router) SetDashboardRouter(router) SetRelayRouter(router) + SetVideoRouter(router) frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" diff --git a/router/video-router.go b/router/video-router.go new file mode 100644 index 00000000..7201c34a --- /dev/null +++ b/router/video-router.go @@ -0,0 +1,17 @@ +package router + +import ( + "one-api/controller" + "one-api/middleware" + + "github.com/gin-gonic/gin" +) + +func SetVideoRouter(router *gin.Engine) { + videoV1Router := router.Group("/v1") + videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + { + videoV1Router.POST("/video/generations", controller.RelayTask) + videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) + } +} diff --git a/web/src/components/table/TaskLogsTable.js b/web/src/components/table/TaskLogsTable.js index b3d0ab7b..37bdde57 100644 --- a/web/src/components/table/TaskLogsTable.js +++ b/web/src/components/table/TaskLogsTable.js @@ -11,7 +11,9 @@ import { XCircle, Loader, List, - Hash + Hash, + Video, + Sparkles } from 'lucide-react'; import { API, @@ -80,6 +82,7 @@ const COLUMN_KEYS = { TASK_STATUS: 'task_status', PROGRESS: 'progress', FAIL_REASON: 'fail_reason', + RESULT_URL: 'result_url', }; const renderTimestamp = (timestampInSeconds) => { @@ -150,6 +153,7 @@ const LogsTable = () => { [COLUMN_KEYS.TASK_STATUS]: true, [COLUMN_KEYS.PROGRESS]: true, [COLUMN_KEYS.FAIL_REASON]: true, + [COLUMN_KEYS.RESULT_URL]: true, }; }; @@ -203,6 +207,12 @@ const LogsTable = () => { {t('生成歌词')} ); + case 'generate': + return ( + }> + {t('生成视频')} + + ); default: return ( }> @@ -220,6 +230,12 @@ const LogsTable = () => { Suno ); + case 'kling': + return ( + }> + Kling + + ); default: return ( }> @@ -411,10 +427,21 @@ const LogsTable = () => { }, { key: COLUMN_KEYS.FAIL_REASON, - title: t('失败原因'), + title: t('详情'), dataIndex: 'fail_reason', fixed: 'right', render: (text, record, index) => { + // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 + const isVideoTask = record.action === 'generate'; + const isSuccess = record.status === 'SUCCESS'; + const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); + if (isSuccess && isVideoTask && isUrl) { + return ( + + {t('点击预览视频')} + + ); + } if (!text) { return t('无'); } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 20fed5b7..c4220bd4 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -125,4 +125,9 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: 'Coze', }, + { + value: 50, + color: 'green', + label: '可灵', + }, ]; From b7c77777a57e0df7be1d05f2ec4dfcb0bcb3e5ca Mon Sep 17 00:00:00 2001 From: skynono Date: Tue, 17 Jun 2025 13:37:07 +0800 Subject: [PATCH 03/19] feat: add video channel kling fix --- controller/task_video.go | 4 +--- relay/channel/task/kling/adaptor.go | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/controller/task_video.go b/controller/task_video.go index 3f2c9588..a2c2431d 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -104,9 +104,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha case "succeed": task.Status = model.TaskStatusSuccess task.Progress = "100%" - if url, err := adaptor.(interface { - ParseResultUrl(map[string]any) (string, error) - }).ParseResultUrl(responseItem); err == nil { + if url, err := adaptor.ParseResultUrl(responseItem); err == nil { task.FailReason = url } else { common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error())) diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 9c6773f5..9ea58728 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -107,7 +107,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { token, err := a.createJWTToken() if err != nil { - token = info.ApiKey // fallback + return fmt.Errorf("failed to create JWT token: %w", err) } req.Header.Set("Content-Type", "application/json") From 616e6953ccc2ce417986c8959dadc54a9dbad1dc Mon Sep 17 00:00:00 2001 From: skynono Date: Tue, 17 Jun 2025 15:35:40 +0800 Subject: [PATCH 04/19] feat: add unsupported test case for kling channel --- controller/channel-test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/controller/channel-test.go b/controller/channel-test.go index d162d8cf..c7e53c13 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr if channel.Type == common.ChannelTypeSunoAPI { return errors.New("suno channel test is not supported"), nil } + if channel.Type == common.ChannelTypeKling { + return errors.New("kling channel test is not supported"), nil + } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) From 1fed1ee5675c12ec93bbcd286081d0610f752de6 Mon Sep 17 00:00:00 2001 From: skynono Date: Thu, 19 Jun 2025 14:41:33 +0800 Subject: [PATCH 05/19] fix: unique channel models --- controller/model.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/controller/model.go b/controller/model.go index 134217a3..78bd32d6 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,6 +2,7 @@ package controller import ( "fmt" + "github.com/samber/lo" "net/http" "one-api/common" "one-api/constant" @@ -136,6 +137,9 @@ func init() { adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() } + openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string { + return m.Id + }) } func ListModels(c *gin.Context) { From fb4ff63bad031cce6742a9a3efe19d2193e9305b Mon Sep 17 00:00:00 2001 From: "Apple\\Apple" Date: Thu, 19 Jun 2025 15:17:05 +0800 Subject: [PATCH 06/19] =?UTF-8?q?=F0=9F=97=91=EF=B8=8F=20chore(custom=20ch?= =?UTF-8?q?annel):=20Remove=20custom=20channel=20support=20from=20upstream?= =?UTF-8?q?=20ratio=20sync?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove all custom channel functionality from the upstream ratio sync feature to simplify the codebase and focus on database-stored channels only. Changes: - Remove custom channel UI components and related state management - Remove custom channel testing and validation logic - Simplify ChannelSelectorModal by removing custom channel input fields - Update API payload to only include channel_ids, removing custom_channels - Remove custom channel processing logic from backend controller - Update import path for DEFAULT_ENDPOINT constant Files modified: - web/src/pages/Setting/Ratio/UpstreamRatioSync.js - web/src/components/settings/ChannelSelectorModal.js - controller/ratio_sync.go This change streamlines the ratio synchronization workflow by focusing solely on pre-configured database channels, reducing complexity and potential maintenance overhead. --- controller/ratio_sync.go | 1 - .../settings/ChannelSelectorModal.js | 40 +------ web/src/constants/common.constant.js | 2 + web/src/helpers/ratio.js | 20 ---- .../pages/Setting/Ratio/UpstreamRatioSync.js | 113 +----------------- 5 files changed, 7 insertions(+), 169 deletions(-) delete mode 100644 web/src/helpers/ratio.js diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index c7494b5b..fae0c59c 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -66,7 +66,6 @@ func FetchUpstreamRatios(c *gin.Context) { }) } } - upstreams = append(upstreams, req.CustomChannels...) var wg sync.WaitGroup ch := make(chan upstreamResult, len(upstreams)) diff --git a/web/src/components/settings/ChannelSelectorModal.js b/web/src/components/settings/ChannelSelectorModal.js index c393d97f..35059473 100644 --- a/web/src/components/settings/ChannelSelectorModal.js +++ b/web/src/components/settings/ChannelSelectorModal.js @@ -3,12 +3,10 @@ import { Modal, Transfer, Input, - Card, Space, - Button, Checkbox, } from '@douyinfe/semi-ui'; -import { IconPlus, IconClose } from '@douyinfe/semi-icons'; +import { IconClose } from '@douyinfe/semi-icons'; /** * ChannelSelectorModal @@ -20,21 +18,13 @@ export default function ChannelSelectorModal({ visible, onCancel, onOk, - // 渠道与选择 + // 渠道选择 allChannels = [], selectedChannelIds = [], setSelectedChannelIds, - // 自定义渠道 - customUrl, - setCustomUrl, - customEndpoint, - setCustomEndpoint, - customChannelTesting, - addCustomChannel, // 渠道端点 channelEndpoints, updateChannelEndpoint, - // 测试相关 }) { // Transfer 自定义渲染 const renderSourceItem = (item) => { @@ -107,32 +97,6 @@ export default function ChannelSelectorModal({ width={1000} > - - - - - - - - { - if (!baseUrl) return endpoint; - try { - return new URL(endpoint, baseUrl).toString(); - } catch (_) { - // fallback 处理不规范的 baseUrl - const cleanedBase = baseUrl.endsWith('/') ? baseUrl.slice(0, -1) : baseUrl; - const cleanedEndpoint = endpoint.startsWith('/') ? endpoint.slice(1) : endpoint; - return `${cleanedBase}/${cleanedEndpoint}`; - } -}; \ No newline at end of file diff --git a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js index 518c6468..ecf5a1b9 100644 --- a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js +++ b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js @@ -11,11 +11,8 @@ import { RefreshCcw, CheckSquare, } from 'lucide-react'; -import { - DEFAULT_ENDPOINT, - buildEndpointUrl, -} from '../../../helpers/ratio'; import { API, showError, showSuccess, showWarning } from '../../../helpers'; +import { DEFAULT_ENDPOINT } from '../../../constants'; import { useTranslation } from 'react-i18next'; import { IllustrationNoResult, @@ -33,11 +30,6 @@ export default function UpstreamRatioSync(props) { const [allChannels, setAllChannels] = useState([]); const [selectedChannelIds, setSelectedChannelIds] = useState([]); - // 自定义渠道 - const [customUrl, setCustomUrl] = useState(''); - const [customEndpoint, setCustomEndpoint] = useState(DEFAULT_ENDPOINT); - const [customChannelTesting, setCustomChannelTesting] = useState(false); - // 渠道端点配置 const [channelEndpoints, setChannelEndpoints] = useState({}); // { channelId: endpoint } @@ -94,86 +86,6 @@ export default function UpstreamRatioSync(props) { } }; - // 测试自定义渠道 - const testCustomChannel = async () => { - if (!customUrl) { - showWarning(t('请输入渠道地址')); - return false; - } - - setCustomChannelTesting(true); - - try { - const url = buildEndpointUrl(customUrl, customEndpoint); - const client = { timeout: 10000 }; - - const response = await fetch(url, { - method: 'GET', - signal: AbortSignal.timeout(client.timeout) - }); - - if (response.ok) { - const data = await response.json(); - if (data.success) { - return true; - } else { - showError(t('测试失败') + `: ${data.message || t('响应格式错误')}`); - return false; - } - } else { - showError(t('测试失败') + `: HTTP ${response.status}`); - return false; - } - } catch (error) { - showError(t('测试失败') + `: ${error.message || t('请求超时')}`); - return false; - } finally { - setCustomChannelTesting(false); - } - }; - - // 添加自定义渠道 - const addCustomChannel = async () => { - if (!customUrl) { - showWarning(t('请输入渠道地址')); - return; - } - - // 先测试渠道 - const testResult = await testCustomChannel(); - if (!testResult) { - return; - } - - let hostname; - try { - hostname = new URL(customUrl).hostname; - } catch (e) { - hostname = customUrl; - } - - const customId = `custom_${Date.now()}`; - const newChannel = { - key: customId, - label: hostname, - value: customId, - disabled: false, - _originalData: { - id: customId, - name: hostname, - base_url: customUrl.endsWith('/') ? customUrl.slice(0, -1) : customUrl, - status: 1, - is_custom: true, - }, - }; - - setAllChannels([...allChannels, newChannel]); - setSelectedChannelIds([...selectedChannelIds, customId]); - setChannelEndpoints(prev => ({ ...prev, [customId]: customEndpoint })); - setCustomUrl(''); - showSuccess(t('测试成功,渠道添加成功')); - }; - // 确认选择渠道 const confirmChannelSelection = () => { const selected = allChannels @@ -193,18 +105,9 @@ export default function UpstreamRatioSync(props) { const fetchRatiosFromChannels = async (channelList) => { setSyncLoading(true); - // 分离数据库渠道和自定义渠道 - const dbChannels = channelList.filter(ch => !ch.is_custom); - const customChannels = channelList.filter(ch => ch.is_custom); - const payload = { - channel_ids: dbChannels.map(ch => parseInt(ch.id)), - custom_channels: customChannels.map(ch => ({ - name: ch.name, - base_url: ch.base_url, - endpoint: channelEndpoints[ch.id] || DEFAULT_ENDPOINT, - })), - timeout: 10 + channel_ids: channelList.map(ch => parseInt(ch.id)), + timeout: 10, }; try { @@ -391,12 +294,10 @@ export default function UpstreamRatioSync(props) { title: t('模型'), dataIndex: 'model', fixed: 'left', - width: 160, }, { title: t('倍率类型'), dataIndex: 'ratioType', - width: 140, render: (text) => { const typeMap = { model_ratio: t('模型倍率'), @@ -410,7 +311,6 @@ export default function UpstreamRatioSync(props) { { title: t('当前值'), dataIndex: 'current', - width: 100, render: (text) => ( {text !== null && text !== undefined ? text : t('未设置')} @@ -486,7 +386,6 @@ export default function UpstreamRatioSync(props) { {upName} ), dataIndex: upName, - width: 140, render: (_, record) => { const upstreamVal = record.upstreams?.[upName]; @@ -582,12 +481,6 @@ export default function UpstreamRatioSync(props) { allChannels={allChannels} selectedChannelIds={selectedChannelIds} setSelectedChannelIds={setSelectedChannelIds} - customUrl={customUrl} - setCustomUrl={setCustomUrl} - customEndpoint={customEndpoint} - setCustomEndpoint={setCustomEndpoint} - customChannelTesting={customChannelTesting} - addCustomChannel={addCustomChannel} channelEndpoints={channelEndpoints} updateChannelEndpoint={updateChannelEndpoint} /> From 67546f4b2adad40ed38f51d177d73daa728fd666 Mon Sep 17 00:00:00 2001 From: "Apple\\Apple" Date: Thu, 19 Jun 2025 16:05:50 +0800 Subject: [PATCH 07/19] =?UTF-8?q?=E2=9C=A8=20chore(ui):=20enhance=20channe?= =?UTF-8?q?l=20selector=20with=20status=20avatars=20and=20UI=20improvement?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add visual status indicators and improve user experience for the upstream ratio sync channel selector modal. Features: - Add status-based avatar indicators for channels (enabled/disabled/auto-disabled) - Implement search functionality with text highlighting - Add endpoint configuration input for each channel - Optimize component structure with reusable ChannelInfo component UI Improvements: - Custom styling for transfer component items - Hide scrollbars for cleaner appearance in transfer lists - Responsive layout adjustments for channel information display - Color-coded avatars: green (enabled), red (disabled), amber (auto-disabled), grey (unknown) Code Quality: - Extract channel status configuration to constants - Create reusable ChannelInfo component to reduce code duplication - Implement proper search filtering for both channel names and URLs - Add consistent styling classes for transfer demo components Files modified: - web/src/components/settings/ChannelSelectorModal.js - web/src/pages/Setting/Ratio/UpstreamRatioSync.js - web/src/index.css This enhancement provides better visual feedback for channel status and improves the overall user experience when selecting channels for ratio synchronization. --- controller/ratio_sync.go | 2 +- .../settings/ChannelSelectorModal.js | 125 +++++++++++------- web/src/index.css | 68 ++++++++++ .../pages/Setting/Ratio/UpstreamRatioSync.js | 8 -- 4 files changed, 144 insertions(+), 59 deletions(-) diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index fae0c59c..490a2a74 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -49,7 +49,7 @@ func FetchUpstreamRatios(c *gin.Context) { req.Timeout = 10 } - // build upstream list from ids + custom + // build upstream list from ids var upstreams []dto.UpstreamDTO if len(req.ChannelIDs) > 0 { // convert []int64 -> []int for model function diff --git a/web/src/components/settings/ChannelSelectorModal.js b/web/src/components/settings/ChannelSelectorModal.js index 35059473..573329b3 100644 --- a/web/src/components/settings/ChannelSelectorModal.js +++ b/web/src/components/settings/ChannelSelectorModal.js @@ -1,92 +1,116 @@ -import React from 'react'; +import React, { useState } from 'react'; import { Modal, Transfer, Input, Space, Checkbox, + Avatar, + Highlight, } from '@douyinfe/semi-ui'; import { IconClose } from '@douyinfe/semi-icons'; -/** - * ChannelSelectorModal - * 负责选择同步渠道、测试与批量测试等 UI,纯展示组件。 - * 业务状态与动作通过 props 注入,保持可复用与可测试。 - */ +const CHANNEL_STATUS_CONFIG = { + 1: { color: 'green', text: '启用' }, + 2: { color: 'red', text: '禁用' }, + 3: { color: 'amber', text: '自禁' }, + default: { color: 'grey', text: '未知' } +}; + +const getChannelStatusConfig = (status) => { + return CHANNEL_STATUS_CONFIG[status] || CHANNEL_STATUS_CONFIG.default; +}; + export default function ChannelSelectorModal({ t, visible, onCancel, onOk, - // 渠道选择 allChannels = [], selectedChannelIds = [], setSelectedChannelIds, - // 渠道端点 channelEndpoints, updateChannelEndpoint, }) { - // Transfer 自定义渲染 - const renderSourceItem = (item) => { + const [searchText, setSearchText] = useState(''); + + const ChannelInfo = ({ item, showEndpoint = false, isSelected = false }) => { const channelId = item.key || item.value; const currentEndpoint = channelEndpoints[channelId]; const baseUrl = item._originalData?.base_url || ''; + const status = item._originalData?.status || 0; + const statusConfig = getChannelStatusConfig(status); return ( -
-
-
- - {item.label} - + <> + + {statusConfig.text} + +
+
+ {isSelected ? ( + item.label + ) : ( + + )}
-
- - {baseUrl} +
+ + {isSelected ? ( + baseUrl + ) : ( + + )} - updateChannelEndpoint(channelId, value)} - placeholder="/api/ratio_config" - className="flex-1 text-xs" - style={{ fontSize: '12px' }} - /> + {showEndpoint && ( + updateChannelEndpoint(channelId, value)} + placeholder="/api/ratio_config" + className="flex-1 text-xs" + style={{ fontSize: '12px' }} + /> + )} + {isSelected && !showEndpoint && ( + + {currentEndpoint} + + )}
+ + ); + }; + + const renderSourceItem = (item) => { + return ( +
+ + +
); }; const renderSelectedItem = (item) => { - const channelId = item.key || item.value; - const currentEndpoint = channelEndpoints[channelId]; - const baseUrl = item._originalData?.base_url || ''; - return ( -
-
-
- {item.label} - -
-
- - {baseUrl} - - - {currentEndpoint} - -
-
+
+ +
); }; - const channelFilter = (input, item) => item.label.toLowerCase().includes(input.toLowerCase()); + const channelFilter = (input, item) => { + const searchLower = input.toLowerCase(); + return item.label.toLowerCase().includes(searchLower) || + (item._originalData?.base_url || '').toLowerCase().includes(searchLower); + }; return ( .semi-table-row { border-bottom: 1px solid rgba(0, 0, 0, 0.1); } +} + +/* ==================== 同步倍率 - 渠道选择器 ==================== */ + +.components-transfer-source-item, +.components-transfer-selected-item { + display: flex; + align-items: center; + padding: 8px; +} + +.semi-transfer-left-list, +.semi-transfer-right-list { + -ms-overflow-style: none; + scrollbar-width: none; +} + +.semi-transfer-left-list::-webkit-scrollbar, +.semi-transfer-right-list::-webkit-scrollbar { + display: none; +} + +.components-transfer-source-item .semi-checkbox, +.components-transfer-selected-item .semi-checkbox { + display: flex; + align-items: center; + width: 100%; +} + +.components-transfer-source-item .semi-avatar, +.components-transfer-selected-item .semi-avatar { + margin-right: 12px; + flex-shrink: 0; +} + +.components-transfer-source-item .info, +.components-transfer-selected-item .info { + flex: 1; + overflow: hidden; + display: flex; + flex-direction: column; + justify-content: center; +} + +.components-transfer-source-item .name, +.components-transfer-selected-item .name { + font-weight: 500; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +.components-transfer-source-item .email, +.components-transfer-selected-item .email { + font-size: 12px; + color: var(--semi-color-text-2); + display: flex; + align-items: center; +} + +.components-transfer-selected-item .semi-icon-close { + margin-left: 8px; + cursor: pointer; + color: var(--semi-color-text-2); +} + +.components-transfer-selected-item .semi-icon-close:hover { + color: var(--semi-color-text-0); } \ No newline at end of file diff --git a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js index ecf5a1b9..2e12fd3b 100644 --- a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js +++ b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js @@ -42,14 +42,6 @@ export default function UpstreamRatioSync(props) { const [currentPage, setCurrentPage] = useState(1); const [pageSize, setPageSize] = useState(10); - // 当前倍率快照 - const currentRatiosSnapshot = useMemo(() => ({ - model_ratio: JSON.parse(props.options.ModelRatio || '{}'), - completion_ratio: JSON.parse(props.options.CompletionRatio || '{}'), - cache_ratio: JSON.parse(props.options.CacheRatio || '{}'), - model_price: JSON.parse(props.options.ModelPrice || '{}'), - }), [props.options]); - // 获取所有渠道 const fetchAllChannels = async () => { setLoading(true); From a9f98c5d392799096c4bd6904f39b66425d68cea Mon Sep 17 00:00:00 2001 From: "Apple\\Apple" Date: Thu, 19 Jun 2025 18:38:43 +0800 Subject: [PATCH 08/19] =?UTF-8?q?=F0=9F=9B=A0=EF=B8=8F=20chore(ratio-sync)?= =?UTF-8?q?:=20improve=20upstream=20ratio=20comparison=20&=20output=20clea?= =?UTF-8?q?nliness?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary 1. Consider “both unset” as identical • When both localValue and upstreamValue are nil, mark upstreamValue as "same" to avoid showing “Not set”. 2. Exclude fully-synced upstream channels from result • Scan `differences` to detect channels that contain at least one divergent value. • Remove channels whose every ratio is either `"same"` or `nil`, so the frontend only receives actionable discrepancies. Why These changes reduce visual noise in the Upstream Ratio Sync table, making it easier for admins to focus on models requiring attention. No functional regressions or breaking API changes are introduced. --- controller/ratio_sync.go | 48 +++++++------- web/src/i18n/locales/en.json | 24 ++++++- .../pages/Setting/Ratio/UpstreamRatioSync.js | 63 ++++++++----------- 3 files changed, 74 insertions(+), 61 deletions(-) diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index 490a2a74..368f92dd 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -29,7 +29,6 @@ type DifferenceItem struct { Upstreams map[string]interface{} `json:"upstreams"` // 上游值:具体值/"same"/null } -// SyncableChannel 可同步的渠道信息 type SyncableChannel struct { ID int `json:"id"` Name string `json:"name"` @@ -37,7 +36,6 @@ type SyncableChannel struct { Status int `json:"status"` } -// FetchUpstreamRatios 后端并发拉取上游倍率 func FetchUpstreamRatios(c *gin.Context) { var req dto.UpstreamRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -49,10 +47,8 @@ func FetchUpstreamRatios(c *gin.Context) { req.Timeout = 10 } - // build upstream list from ids var upstreams []dto.UpstreamDTO if len(req.ChannelIDs) > 0 { - // convert []int64 -> []int for model function intIds := make([]int, 0, len(req.ChannelIDs)) for _, id64 := range req.ChannelIDs { intIds = append(intIds, int(id64)) @@ -62,7 +58,7 @@ func FetchUpstreamRatios(c *gin.Context) { upstreams = append(upstreams, dto.UpstreamDTO{ Name: ch.Name, BaseURL: ch.GetBaseURL(), - Endpoint: "", // assume default endpoint + Endpoint: "", }) } } @@ -110,7 +106,6 @@ func FetchUpstreamRatios(c *gin.Context) { wg.Wait() close(ch) - // 本地倍率配置 localData := ratio_setting.GetExposedData() var testResults []dto.TestResult @@ -138,7 +133,6 @@ func FetchUpstreamRatios(c *gin.Context) { } } - // 构建差异化数据 differences := buildDifferences(localData, successfulChannels) c.JSON(http.StatusOK, gin.H{ @@ -150,7 +144,6 @@ func FetchUpstreamRatios(c *gin.Context) { }) } -// buildDifferences 构建差异化数据,只返回有意义的差异 func buildDifferences(localData map[string]any, successfulChannels []struct { name string data map[string]any @@ -158,10 +151,8 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { differences := make(map[string]map[string]dto.DifferenceItem) ratioTypes := []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} - // 收集所有模型名称 allModels := make(map[string]struct{}) - // 从本地数据收集模型名称 for _, ratioType := range ratioTypes { if localRatioAny, ok := localData[ratioType]; ok { if localRatio, ok := localRatioAny.(map[string]float64); ok { @@ -172,7 +163,6 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { } } - // 从上游数据收集模型名称 for _, channel := range successfulChannels { for _, ratioType := range ratioTypes { if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { @@ -183,10 +173,8 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { } } - // 对每个模型和每个比率类型进行分析 for modelName := range allModels { for _, ratioType := range ratioTypes { - // 获取本地值 var localValue interface{} = nil if localRatioAny, ok := localData[ratioType]; ok { if localRatio, ok := localRatioAny.(map[string]float64); ok { @@ -196,7 +184,6 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { } } - // 收集上游值 upstreamValues := make(map[string]interface{}) hasUpstreamValue := false hasDifference := false @@ -209,7 +196,6 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { upstreamValue = val hasUpstreamValue = true - // 检查是否与本地值不同 if localValue != nil && localValue != val { hasDifference = true } else if localValue == val { @@ -217,8 +203,10 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { } } } + if upstreamValue == nil && localValue == nil { + upstreamValue = "same" + } - // 如果本地值为空但上游有值,这也是差异 if localValue == nil && upstreamValue != nil && upstreamValue != "same" { hasDifference = true } @@ -226,17 +214,13 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { upstreamValues[channel.name] = upstreamValue } - // 应用过滤逻辑 shouldInclude := false if localValue != nil { - // 规则1: 本地值存在,至少有一个上游与本地值不同 if hasDifference { shouldInclude = true } - // 规则2: 本地值存在,但所有上游都未设置 - 不包含 } else { - // 规则3: 本地值不存在,至少有一个上游设置了值 if hasUpstreamValue { shouldInclude = true } @@ -254,10 +238,31 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { } } + channelHasDiff := make(map[string]bool) + for _, ratioMap := range differences { + for _, item := range ratioMap { + for chName, val := range item.Upstreams { + if val != nil && val != "same" { + channelHasDiff[chName] = true + } + } + } + } + + for modelName, ratioMap := range differences { + for ratioType, item := range ratioMap { + for chName := range item.Upstreams { + if !channelHasDiff[chName] { + delete(item.Upstreams, chName) + } + } + differences[modelName][ratioType] = item + } + } + return differences } -// GetSyncableChannels 获取可用于倍率同步的渠道(base_url 不为空的渠道) func GetSyncableChannels(c *gin.Context) { channels, err := model.GetAllChannels(0, 0, true, false) if err != nil { @@ -270,7 +275,6 @@ func GetSyncableChannels(c *gin.Context) { var syncableChannels []dto.SyncableChannel for _, channel := range channels { - // 只返回 base_url 不为空的渠道 if channel.GetBaseURL() != "" { syncableChannels = append(syncableChannels, dto.SyncableChannel{ ID: channel.Id, diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index fc80f9c1..b8e1afd8 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1665,5 +1665,27 @@ "确定清除所有失效兑换码?": "Are you sure you want to clear all invalid redemption codes?", "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "This will delete all used, disabled, and expired redemption codes, this operation cannot be undone.", "选择过期时间(可选,留空为永久)": "Select expiration time (optional, leave blank for permanent)", - "请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)" + "请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)", + "上游倍率同步": "Upstream ratio synchronization", + "获取渠道失败:": "Failed to get channels: ", + "请至少选择一个渠道": "Please select at least one channel", + "获取倍率失败:": "Failed to get ratios: ", + "后端请求失败": "Backend request failed", + "部分渠道测试失败:": "Some channels failed to test: ", + "已与上游倍率完全一致,无需同步": "The upstream ratio is completely consistent, no synchronization is required", + "请求后端接口失败:": "Failed to request the backend interface: ", + "同步成功": "Synchronization successful", + "部分保存失败": "Some settings failed to save", + "保存失败": "Save failed", + "选择同步渠道": "Select synchronization channel", + "应用同步": "Apply synchronization", + "倍率类型": "Ratio type", + "当前值": "Current value", + "上游值": "Upstream value", + "差异": "Difference", + "搜索渠道名称或地址": "Search channel name or address", + "缓存倍率": "Cache ratio", + "暂无差异化倍率显示": "No differential ratio display", + "请先选择同步渠道": "Please select the synchronization channel first", + "与本地相同": "Same as local" } \ No newline at end of file diff --git a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js index 2e12fd3b..aae6d9f3 100644 --- a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js +++ b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js @@ -11,7 +11,7 @@ import { RefreshCcw, CheckSquare, } from 'lucide-react'; -import { API, showError, showSuccess, showWarning } from '../../../helpers'; +import { API, showError, showSuccess, showWarning, stringToColor } from '../../../helpers'; import { DEFAULT_ENDPOINT } from '../../../constants'; import { useTranslation } from 'react-i18next'; import { @@ -35,14 +35,12 @@ export default function UpstreamRatioSync(props) { // 差异数据和测试结果 const [differences, setDifferences] = useState({}); - const [testResults, setTestResults] = useState([]); const [resolutions, setResolutions] = useState({}); // 分页相关状态 const [currentPage, setCurrentPage] = useState(1); const [pageSize, setPageSize] = useState(10); - // 获取所有渠道 const fetchAllChannels = async () => { setLoading(true); try { @@ -51,18 +49,16 @@ export default function UpstreamRatioSync(props) { if (res.data.success) { const channels = res.data.data || []; - // 转换为Transfer组件所需格式 const transferData = channels.map(channel => ({ key: channel.id, label: channel.name, value: channel.id, - disabled: false, // 所有渠道都可以选择 + disabled: false, _originalData: channel, })); setAllChannels(transferData); - // 初始化端点配置 const initialEndpoints = {}; transferData.forEach(channel => { initialEndpoints[channel.key] = DEFAULT_ENDPOINT; @@ -78,7 +74,6 @@ export default function UpstreamRatioSync(props) { } }; - // 确认选择渠道 const confirmChannelSelection = () => { const selected = allChannels .filter(ch => selectedChannelIds.includes(ch.value)) @@ -93,7 +88,6 @@ export default function UpstreamRatioSync(props) { fetchRatiosFromChannels(selected); }; - // 从选定渠道获取倍率 const fetchRatiosFromChannels = async (channelList) => { setSyncLoading(true); @@ -113,17 +107,14 @@ export default function UpstreamRatioSync(props) { const { differences = {}, test_results = [] } = res.data.data; - // 显示测试结果 const errorResults = test_results.filter(r => r.status === 'error'); if (errorResults.length > 0) { showWarning(t('部分渠道测试失败:') + errorResults.map(r => `${r.name}: ${r.error}`).join(', ')); } setDifferences(differences); - setTestResults(test_results); setResolutions({}); - // 判断是否有差异 if (Object.keys(differences).length === 0) { showSuccess(t('已与上游倍率完全一致,无需同步')); } @@ -134,7 +125,6 @@ export default function UpstreamRatioSync(props) { } }; - // 解决冲突/选择值 const selectValue = (model, ratioType, value) => { setResolutions(prev => ({ ...prev, @@ -145,7 +135,6 @@ export default function UpstreamRatioSync(props) { })); }; - // 应用同步 const applySync = async () => { const currentRatios = { ModelRatio: JSON.parse(props.options.ModelRatio || '{}'), @@ -154,7 +143,6 @@ export default function UpstreamRatioSync(props) { ModelPrice: JSON.parse(props.options.ModelPrice || '{}'), }; - // 应用已选择的值 Object.entries(resolutions).forEach(([model, ratios]) => { Object.entries(ratios).forEach(([ratioType, value]) => { const optionKey = ratioType @@ -165,7 +153,6 @@ export default function UpstreamRatioSync(props) { }); }); - // 保存到后端 setLoading(true); try { const updates = Object.entries(currentRatios).map(([key, value]) => @@ -180,11 +167,26 @@ export default function UpstreamRatioSync(props) { if (results.every(res => res.data.success)) { showSuccess(t('同步成功')); props.refresh(); - // 清空状态 - setDifferences({}); - setTestResults([]); + + setDifferences(prevDifferences => { + const newDifferences = { ...prevDifferences }; + + Object.entries(resolutions).forEach(([model, ratios]) => { + Object.keys(ratios).forEach(ratioType => { + if (newDifferences[model] && newDifferences[model][ratioType]) { + delete newDifferences[model][ratioType]; + + if (Object.keys(newDifferences[model]).length === 0) { + delete newDifferences[model]; + } + } + }); + }); + + return newDifferences; + }); + setResolutions({}); - setSelectedChannelIds([]); } else { showError(t('部分保存失败')); } @@ -195,14 +197,12 @@ export default function UpstreamRatioSync(props) { } }; - // 计算当前页显示的数据 const getCurrentPageData = (dataSource) => { const startIndex = (currentPage - 1) * pageSize; const endIndex = startIndex + pageSize; return dataSource.slice(startIndex, endIndex); }; - // 渲染表格头部 const renderHeader = () => (
@@ -219,7 +219,6 @@ export default function UpstreamRatioSync(props) { {(() => { - // 检查是否有选择可应用的值 const hasSelections = Object.keys(resolutions).length > 0; return ( @@ -239,9 +238,7 @@ export default function UpstreamRatioSync(props) {
); - // 渲染差异表格 const renderDifferenceTable = () => { - // 构建数据源 const dataSource = useMemo(() => { const tmp = []; @@ -260,7 +257,6 @@ export default function UpstreamRatioSync(props) { return tmp; }, [differences]); - // 收集所有上游渠道名称 const upstreamNames = useMemo(() => { const set = new Set(); dataSource.forEach((row) => { @@ -274,13 +270,12 @@ export default function UpstreamRatioSync(props) { } darkModeImage={} - description={Object.keys(differences).length === 0 ? t('已与上游倍率完全一致') : t('请先选择同步渠道')} + description={Object.keys(differences).length === 0 ? t('暂无差异化倍率显示') : t('请先选择同步渠道')} style={{ padding: 30 }} /> ); } - // 列定义 const columns = [ { title: t('模型'), @@ -297,7 +292,7 @@ export default function UpstreamRatioSync(props) { cache_ratio: t('缓存倍率'), model_price: t('固定价格'), }; - return {typeMap[text] || text}; + return {typeMap[text] || text}; }, }, { @@ -309,16 +304,13 @@ export default function UpstreamRatioSync(props) { ), }, - // 动态上游列 ...upstreamNames.map((upName) => { - // 计算该渠道的全选状态 const channelStats = (() => { - let selectableCount = 0; // 可选择的项目数量 - let selectedCount = 0; // 已选择的项目数量 + let selectableCount = 0; + let selectedCount = 0; dataSource.forEach((row) => { const upstreamVal = row.upstreams?.[upName]; - // 只有具体数值的才是可选择的(不是null、undefined或"same") if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { selectableCount++; const isSelected = resolutions[row.model]?.[row.ratioType] === upstreamVal; @@ -337,7 +329,6 @@ export default function UpstreamRatioSync(props) { }; })(); - // 处理全选/取消全选 const handleBulkSelect = (checked) => { setResolutions((prev) => { const newRes = { ...prev }; @@ -346,11 +337,9 @@ export default function UpstreamRatioSync(props) { const upstreamVal = row.upstreams?.[upName]; if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { if (checked) { - // 选择该值 if (!newRes[row.model]) newRes[row.model] = {}; newRes[row.model][row.ratioType] = upstreamVal; } else { - // 取消选择该值 if (newRes[row.model]) { delete newRes[row.model][row.ratioType]; if (Object.keys(newRes[row.model]).length === 0) { @@ -389,7 +378,6 @@ export default function UpstreamRatioSync(props) { return {t('与本地相同')}; } - // 有具体值,可以选择 const isSelected = resolutions[record.model]?.[record.ratioType] === upstreamVal; return ( @@ -454,7 +442,6 @@ export default function UpstreamRatioSync(props) { ); }; - // 更新渠道端点 const updateChannelEndpoint = useCallback((channelId, endpoint) => { setChannelEndpoints(prev => ({ ...prev, [channelId]: endpoint })); }, []); From 458472f3e2e3bb84946ee27aaa92d1fd63b6713a Mon Sep 17 00:00:00 2001 From: "Apple\\Apple" Date: Thu, 19 Jun 2025 18:54:46 +0800 Subject: [PATCH 09/19] =?UTF-8?q?=F0=9F=94=8D=20feat(ratio-sync):=20add=20?= =?UTF-8?q?fuzzy=20model=20search=20&=20enhance=20empty-state=20UX?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary 1. Add model name search box • Introduce Semi UI `Input` with `IconSearch` prefix next to the “Apply Sync” button. • Support case-insensitive fuzzy matching of model names. • Real-time filtering, pagination and bulk-select logic now work on filtered data. 2. Improve empty state handling • Add `hasSynced` flag to distinguish “not synced yet” from “synced with no differences”. • Display messages: – “Please select sync channels” when no sync has been performed. – “No differences found” when a sync completed with zero discrepancies. – “No matching model found” when search yields no results. 3. UI tweaks • Replace lucide-react `Search` icon with Semi UI `IconSearch` for visual consistency. • Keep responsive width and clearable input for better usability. Why These changes allow admins to quickly locate specific models and provide accurate feedback on the sync status, greatly improving the usability of the Upstream Ratio Sync page. --- web/src/i18n/locales/en.json | 3 +- .../pages/Setting/Ratio/UpstreamRatioSync.js | 53 +++++++++++++++---- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index b8e1afd8..ab793364 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1687,5 +1687,6 @@ "缓存倍率": "Cache ratio", "暂无差异化倍率显示": "No differential ratio display", "请先选择同步渠道": "Please select the synchronization channel first", - "与本地相同": "Same as local" + "与本地相同": "Same as local", + "未找到匹配的模型": "No matching model found" } \ No newline at end of file diff --git a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js index aae6d9f3..f83e0cdc 100644 --- a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js +++ b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js @@ -6,7 +6,9 @@ import { Empty, Checkbox, Form, + Input, } from '@douyinfe/semi-ui'; +import { IconSearch } from '@douyinfe/semi-icons'; import { RefreshCcw, CheckSquare, @@ -37,10 +39,16 @@ export default function UpstreamRatioSync(props) { const [differences, setDifferences] = useState({}); const [resolutions, setResolutions] = useState({}); + // 是否已经执行过同步 + const [hasSynced, setHasSynced] = useState(false); + // 分页相关状态 const [currentPage, setCurrentPage] = useState(1); const [pageSize, setPageSize] = useState(10); + // 搜索相关状态 + const [searchKeyword, setSearchKeyword] = useState(''); + const fetchAllChannels = async () => { setLoading(true); try { @@ -114,6 +122,7 @@ export default function UpstreamRatioSync(props) { setDifferences(differences); setResolutions({}); + setHasSynced(true); if (Object.keys(differences).length === 0) { showSuccess(t('已与上游倍率完全一致,无需同步')); @@ -233,6 +242,15 @@ export default function UpstreamRatioSync(props) { ); })()} + + } + placeholder={t('搜索模型名称')} + value={searchKeyword} + onChange={setSearchKeyword} + className="!rounded-full w-full md:w-64 mt-2" + showClear + />
@@ -257,20 +275,37 @@ export default function UpstreamRatioSync(props) { return tmp; }, [differences]); + const filteredDataSource = useMemo(() => { + if (!searchKeyword.trim()) { + return dataSource; + } + + const keyword = searchKeyword.toLowerCase().trim(); + return dataSource.filter(item => + item.model.toLowerCase().includes(keyword) + ); + }, [dataSource, searchKeyword]); + const upstreamNames = useMemo(() => { const set = new Set(); - dataSource.forEach((row) => { + filteredDataSource.forEach((row) => { Object.keys(row.upstreams || {}).forEach((name) => set.add(name)); }); return Array.from(set); - }, [dataSource]); + }, [filteredDataSource]); - if (dataSource.length === 0) { + if (filteredDataSource.length === 0) { return ( } darkModeImage={} - description={Object.keys(differences).length === 0 ? t('暂无差异化倍率显示') : t('请先选择同步渠道')} + description={ + searchKeyword.trim() + ? t('未找到匹配的模型') + : (Object.keys(differences).length === 0 ? + (hasSynced ? t('暂无差异化倍率显示') : t('请先选择同步渠道')) + : t('请先选择同步渠道')) + } style={{ padding: 30 }} /> ); @@ -309,7 +344,7 @@ export default function UpstreamRatioSync(props) { let selectableCount = 0; let selectedCount = 0; - dataSource.forEach((row) => { + filteredDataSource.forEach((row) => { const upstreamVal = row.upstreams?.[upName]; if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { selectableCount++; @@ -333,7 +368,7 @@ export default function UpstreamRatioSync(props) { setResolutions((prev) => { const newRes = { ...prev }; - dataSource.forEach((row) => { + filteredDataSource.forEach((row) => { const upstreamVal = row.upstreams?.[upName]; if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { if (checked) { @@ -412,17 +447,17 @@ export default function UpstreamRatioSync(props) { return (
t('第 {{start}} - {{end}} 条,共 {{total}} 条', { start: page.currentStart, end: page.currentEnd, - total: dataSource.length, + total: filteredDataSource.length, }), pageSizeOptions: ['5', '10', '20', '50'], onChange: (page, size) => { From 150c506ece9cee198f396e7059f5a4746883ec13 Mon Sep 17 00:00:00 2001 From: "Apple\\Apple" Date: Thu, 19 Jun 2025 19:55:51 +0800 Subject: [PATCH 10/19] =?UTF-8?q?=F0=9F=9A=80=20chore(controller,=20dto):?= =?UTF-8?q?=20elevate=20ratio-sync=20feature=20to=20production=20readiness?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WHAT’S NEW • controller/ratio_sync.go – Deleted unused local structs (TestResult, DifferenceItem, SyncableChannel). – Centralised config with constants: defaultTimeoutSeconds, defaultEndpoint, maxConcurrentFetches, ratioTypes. – Replaced magic numbers; added semaphore-based concurrency limit and shared http.Client (with TLS & Expect-Continue timeouts). – Added comprehensive error handling and context-aware logging via common.Log* helpers. – Checked DB errors from GetChannelsByIds; early-return on failures or empty upstream list. – Removed custom-channel support; logic now relies solely on ChannelIDs. – Minor clean-ups: import grouping, string trimming, endpoint normalisation. • dto/ratio_sync.go – Simplified UpstreamRequest: dropped unused CustomChannels field. WHY These improvements harden the ratio-sync endpoint for production use by preventing silent failures, controlling resource usage, and making behaviour configurable and observable. HOW No business logic change—only structural refactor, logging, and safeguards—so existing API contracts (aside from removed custom_channels) remain intact. --- controller/ratio_sync.go | 97 ++++++++++++++++++++++++++-------------- dto/ratio_sync.go | 5 +-- 2 files changed, 65 insertions(+), 37 deletions(-) diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index 368f92dd..f749f384 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -1,41 +1,35 @@ package controller import ( + "context" "encoding/json" "net/http" - "one-api/model" - "one-api/setting/ratio_setting" - "one-api/dto" + "strings" "sync" "time" + "one-api/common" + "one-api/dto" + "one-api/model" + "one-api/setting/ratio_setting" + "github.com/gin-gonic/gin" ) +const ( + defaultTimeoutSeconds = 10 + defaultEndpoint = "/api/ratio_config" + maxConcurrentFetches = 8 +) + +var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} + type upstreamResult struct { Name string `json:"name"` Data map[string]any `json:"data,omitempty"` Err string `json:"err,omitempty"` } -type TestResult struct { - Name string `json:"name"` - Status string `json:"status"` - Error string `json:"error,omitempty"` -} - -type DifferenceItem struct { - Current interface{} `json:"current"` // 当前本地值,可能为null - Upstreams map[string]interface{} `json:"upstreams"` // 上游值:具体值/"same"/null -} - -type SyncableChannel struct { - ID int `json:"id"` - Name string `json:"name"` - BaseURL string `json:"base_url"` - Status int `json:"status"` -} - func FetchUpstreamRatios(c *gin.Context) { var req dto.UpstreamRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -44,45 +38,80 @@ func FetchUpstreamRatios(c *gin.Context) { } if req.Timeout <= 0 { - req.Timeout = 10 + req.Timeout = defaultTimeoutSeconds } var upstreams []dto.UpstreamDTO + if len(req.ChannelIDs) > 0 { intIds := make([]int, 0, len(req.ChannelIDs)) for _, id64 := range req.ChannelIDs { intIds = append(intIds, int(id64)) } - dbChannels, _ := model.GetChannelsByIds(intIds) - for _, ch := range dbChannels { - upstreams = append(upstreams, dto.UpstreamDTO{ - Name: ch.Name, - BaseURL: ch.GetBaseURL(), - Endpoint: "", - }) + dbChannels, err := model.GetChannelsByIds(intIds) + if err != nil { + common.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) + return } + for _, ch := range dbChannels { + if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { + upstreams = append(upstreams, dto.UpstreamDTO{ + Name: ch.Name, + BaseURL: strings.TrimRight(base, "/"), + Endpoint: "", + }) + } + } + } + + if len(upstreams) == 0 { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) + return } var wg sync.WaitGroup ch := make(chan upstreamResult, len(upstreams)) + sem := make(chan struct{}, maxConcurrentFetches) + + client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} + for _, chn := range upstreams { wg.Add(1) go func(chItem dto.UpstreamDTO) { defer wg.Done() + + sem <- struct{}{} + defer func() { <-sem }() + endpoint := chItem.Endpoint if endpoint == "" { - endpoint = "/api/ratio_config" + endpoint = defaultEndpoint + } else if !strings.HasPrefix(endpoint, "/") { + endpoint = "/" + endpoint } - url := chItem.BaseURL + endpoint - client := http.Client{Timeout: time.Duration(req.Timeout) * time.Second} - resp, err := client.Get(url) + fullURL := chItem.BaseURL + endpoint + + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) + defer cancel() + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) if err != nil { + common.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + + resp, err := client.Do(httpReq) + if err != nil { + common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) ch <- upstreamResult{Name: chItem.Name, Err: resp.Status} return } @@ -92,6 +121,7 @@ func FetchUpstreamRatios(c *gin.Context) { Message string `json:"message"` } if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} return } @@ -149,7 +179,6 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { data map[string]any }) map[string]map[string]dto.DifferenceItem { differences := make(map[string]map[string]dto.DifferenceItem) - ratioTypes := []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} allModels := make(map[string]struct{}) diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go index 4f2fe06d..55a89025 100644 --- a/dto/ratio_sync.go +++ b/dto/ratio_sync.go @@ -19,9 +19,8 @@ type UpstreamDTO struct { } type UpstreamRequest struct { - ChannelIDs []int64 `json:"channel_ids"` - CustomChannels []UpstreamDTO `json:"custom_channels"` - Timeout int `json:"timeout"` + ChannelIDs []int64 `json:"channel_ids"` + Timeout int `json:"timeout"` } // TestResult 上游测试连通性结果 From b087b20bac52f505e409cb23915f5c9fee6845dc Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 20 Jun 2025 14:53:27 +0800 Subject: [PATCH 11/19] refactor: update error handling in ClaudeHelper and GeminiHelper --- relay/claude_handler.go | 2 +- relay/relay-gemini.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/relay/claude_handler.go b/relay/claude_handler.go index e8805255..567378fb 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -126,7 +126,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { - return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } if resp != nil { diff --git a/relay/relay-gemini.go b/relay/relay-gemini.go index 80e5a694..455b31b7 100644 --- a/relay/relay-gemini.go +++ b/relay/relay-gemini.go @@ -162,7 +162,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody)) if err != nil { common.LogError(c, "Do gemini request failed: "+err.Error()) - return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } statusCodeMappingStr := c.GetString("status_code_mapping") From d3286893c4b1a970e0ee7d5d9bcad92daa2e30f2 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 20 Jun 2025 16:02:23 +0800 Subject: [PATCH 12/19] feat: implement new handlers for audio, image, embedding, and responses processing - Added new handlers: AudioHelper, ImageHelper, EmbeddingHelper, and ResponsesHelper to manage respective requests. - Updated ModelMappedHelper to accept request parameters for better model mapping. - Enhanced error handling and validation across new handlers to ensure robust request processing. - Introduced support for new relay formats in relay_info and updated relevant functions accordingly. --- controller/channel-test.go | 2 +- relay/{relay-audio.go => audio_handler.go} | 6 +-- relay/channel/gemini/adaptor.go | 6 +-- relay/channel/gemini/relay-gemini.go | 2 +- relay/claude_handler.go | 4 +- relay/common/relay_info.go | 45 ++++++++++++++++--- ...elay_embedding.go => embedding_handler.go} | 6 +-- relay/{relay-gemini.go => gemini_handler.go} | 4 +- relay/helper/model_mapped.go | 40 ++++++++++++++++- relay/{relay-image.go => image_handler.go} | 6 +-- relay/relay-text.go | 4 +- relay/{relay_rerank.go => rerank_handler.go} | 4 +- ...elay-responses.go => responses_handler.go} | 4 +- 13 files changed, 95 insertions(+), 38 deletions(-) rename relay/{relay-audio.go => audio_handler.go} (96%) rename relay/{relay_embedding.go => embedding_handler.go} (96%) rename relay/{relay-gemini.go => gemini_handler.go} (98%) rename relay/{relay-image.go => image_handler.go} (98%) rename relay/{relay_rerank.go => rerank_handler.go} (97%) rename relay/{relay-responses.go => responses_handler.go} (98%) diff --git a/controller/channel-test.go b/controller/channel-test.go index d162d8cf..26c97056 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -90,7 +90,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr info := relaycommon.GenRelayInfo(c) - err = helper.ModelMappedHelper(c, info) + err = helper.ModelMappedHelper(c, info, nil) if err != nil { return err, nil } diff --git a/relay/relay-audio.go b/relay/audio_handler.go similarity index 96% rename from relay/relay-audio.go rename to relay/audio_handler.go index deb45c58..e55de042 100644 --- a/relay/relay-audio.go +++ b/relay/audio_handler.go @@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. } func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c) audioRequest, err := getAndValidAudioRequest(c, relayInfo) if err != nil { @@ -89,13 +89,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } }() - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, audioRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - audioRequest.Model = relayInfo.UpstreamModelName - adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index a81eb3a9..968d9c9b 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { // 新增逻辑:处理 -thinking- 格式 - if strings.Contains(info.OriginModelName, "-thinking-") { + if strings.Contains(info.UpstreamModelName, "-thinking-") { parts := strings.Split(info.UpstreamModelName, "-thinking-") info.UpstreamModelName = parts[0] - } else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配 + } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配 info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") - } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { + } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") } } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index d4b7c209..ef2c35be 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -99,7 +99,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - modelName := info.OriginModelName + modelName := info.UpstreamModelName isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") && !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25") diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 567378fb..42139ddf 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { relayInfo.IsStream = true } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - textRequest.Model = relayInfo.UpstreamModelName - promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index a842a58d..3759c363 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -34,9 +34,14 @@ type ClaudeConvertInfo struct { } const ( - RelayFormatOpenAI = "openai" - RelayFormatClaude = "claude" - RelayFormatGemini = "gemini" + RelayFormatOpenAI = "openai" + RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" + RelayFormatOpenAIResponses = "openai_responses" + RelayFormatOpenAIAudio = "openai_audio" + RelayFormatOpenAIImage = "openai_image" + RelayFormatRerank = "rerank" + RelayFormatEmbedding = "embedding" ) type RerankerInfo struct { @@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { info := GenRelayInfo(c) info.RelayMode = relayconstant.RelayModeRerank + info.RelayFormat = RelayFormatRerank info.RerankerInfo = &RerankerInfo{ Documents: req.Documents, ReturnDocuments: req.GetReturnDocuments(), @@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { return info } +func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatOpenAIAudio + return info +} + +func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatEmbedding + return info +} + func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { info := GenRelayInfo(c) info.RelayMode = relayconstant.RelayModeResponses + info.RelayFormat = RelayFormatOpenAIResponses + + info.SupportStreamOptions = false + info.ResponsesUsageInfo = &ResponsesUsageInfo{ BuiltInTools: make(map[string]*BuildInToolInfo), } @@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel return info } +func GenRelayInfoGemini(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatGemini + info.ShouldIncludeUsage = false + return info +} + +func GenRelayInfoImage(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatOpenAIImage + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") @@ -243,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if streamSupportedChannels[info.ChannelType] { info.SupportStreamOptions = true } - // responses 模式不支持 StreamOptions - if relayconstant.RelayModeResponses == info.RelayMode { - info.SupportStreamOptions = false - } return info } diff --git a/relay/relay_embedding.go b/relay/embedding_handler.go similarity index 96% rename from relay/relay_embedding.go rename to relay/embedding_handler.go index b4909849..fbf4990a 100644 --- a/relay/relay_embedding.go +++ b/relay/embedding_handler.go @@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed } func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoEmbedding(c) var embeddingRequest *dto.EmbeddingRequest err := common.UnmarshalBodyReusable(c, &embeddingRequest) @@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - embeddingRequest.Model = relayInfo.UpstreamModelName - promptToken := getEmbeddingPromptToken(*embeddingRequest) relayInfo.PromptTokens = promptToken diff --git a/relay/relay-gemini.go b/relay/gemini_handler.go similarity index 98% rename from relay/relay-gemini.go rename to relay/gemini_handler.go index 455b31b7..fa41cc7b 100644 --- a/relay/relay-gemini.go +++ b/relay/gemini_handler.go @@ -83,7 +83,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest) } - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoGemini(c) // 检查 Gemini 流式模式 checkGeminiStreamMode(c, relayInfo) @@ -97,7 +97,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } // model mapped 模型映射 - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) } diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go index 9bf67c03..c1735149 100644 --- a/relay/helper/model_mapped.go +++ b/relay/helper/model_mapped.go @@ -4,12 +4,14 @@ import ( "encoding/json" "errors" "fmt" + common2 "one-api/common" + "one-api/dto" "one-api/relay/common" "github.com/gin-gonic/gin" ) -func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { +func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error { // map model name modelMapping := c.GetString("model_mapping") if modelMapping != "" && modelMapping != "{}" { @@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { info.UpstreamModelName = currentModel } } + if request != nil { + switch info.RelayFormat { + case common.RelayFormatGemini: + // Gemini 模型映射 + case common.RelayFormatClaude: + if claudeRequest, ok := request.(*dto.ClaudeRequest); ok { + claudeRequest.Model = info.UpstreamModelName + } + case common.RelayFormatOpenAIResponses: + if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok { + openAIResponsesRequest.Model = info.UpstreamModelName + } + case common.RelayFormatOpenAIAudio: + if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok { + openAIAudioRequest.Model = info.UpstreamModelName + } + case common.RelayFormatOpenAIImage: + if imageRequest, ok := request.(*dto.ImageRequest); ok { + imageRequest.Model = info.UpstreamModelName + } + case common.RelayFormatRerank: + if rerankRequest, ok := request.(*dto.RerankRequest); ok { + rerankRequest.Model = info.UpstreamModelName + } + case common.RelayFormatEmbedding: + if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok { + embeddingRequest.Model = info.UpstreamModelName + } + default: + if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok { + openAIRequest.Model = info.UpstreamModelName + } else { + common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request)) + } + } + } return nil } diff --git a/relay/relay-image.go b/relay/image_handler.go similarity index 98% rename from relay/relay-image.go rename to relay/image_handler.go index 197a8af6..57917025 100644 --- a/relay/relay-image.go +++ b/relay/image_handler.go @@ -102,7 +102,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. } func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoImage(c) imageRequest, err := getAndValidImageRequest(c, relayInfo) if err != nil { @@ -110,13 +110,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, imageRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - imageRequest.Model = relayInfo.UpstreamModelName - priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) diff --git a/relay/relay-text.go b/relay/relay-text.go index 24fb8155..bf5a0259 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -108,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - textRequest.Model = relayInfo.UpstreamModelName - // 获取 promptTokens,如果上下文中已经存在,则直接使用 var promptTokens int if value, exists := c.Get("prompt_tokens"); exists { diff --git a/relay/relay_rerank.go b/relay/rerank_handler.go similarity index 97% rename from relay/relay_rerank.go rename to relay/rerank_handler.go index 6ca98de7..4d02c84f 100644 --- a/relay/relay_rerank.go +++ b/relay/rerank_handler.go @@ -42,13 +42,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - rerankRequest.Model = relayInfo.UpstreamModelName - promptToken := getRerankPromptToken(*rerankRequest) relayInfo.PromptTokens = promptToken diff --git a/relay/relay-responses.go b/relay/responses_handler.go similarity index 98% rename from relay/relay-responses.go rename to relay/responses_handler.go index fd3ddb5a..8e8a3451 100644 --- a/relay/relay-responses.go +++ b/relay/responses_handler.go @@ -63,11 +63,11 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) } } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) } - req.Model = relayInfo.UpstreamModelName + if value, exists := c.Get("prompt_tokens"); exists { promptTokens := value.(int) relayInfo.SetPromptTokens(promptTokens) From 9344cab59a8aea1b4c44e5d55823485942089c73 Mon Sep 17 00:00:00 2001 From: RedwindA Date: Fri, 20 Jun 2025 16:40:51 +0800 Subject: [PATCH 13/19] fix: update model name logic for vertex --- relay/channel/vertex/adaptor.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 424234a8..e568f651 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -83,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { suffix := "" if a.RequestMode == RequestModeGemini { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - // suffix -thinking and -nothinking - if strings.HasSuffix(info.OriginModelName, "-thinking") { + // 新增逻辑:处理 -thinking- 格式 + if strings.Contains(info.UpstreamModelName, "-thinking-") { + parts := strings.Split(info.UpstreamModelName, "-thinking-") + info.UpstreamModelName = parts[0] + } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配 info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") - } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { + } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") } } From a1a7ddbc8395d416979032a21401534567028824 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 20 Jun 2025 17:48:55 +0800 Subject: [PATCH 14/19] fix: update payment method handling in topup controller - Refactored payment method validation to check against available methods. - Changed payment method types from "zfb" to "alipay" and "wx" to "wxpay" for consistency. - Updated the purchase request to use the validated payment method directly. --- controller/topup.go | 14 ++++++-------- setting/payment.go | 4 ++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/controller/topup.go b/controller/topup.go index 951b2cf2..827dda39 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) { c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) return } - payType := "wxpay" - if req.PaymentMethod == "zfb" { - payType = "alipay" - } - if req.PaymentMethod == "wx" { - req.PaymentMethod = "wxpay" - payType = "wxpay" + + if !setting.ContainsPayMethod(req.PaymentMethod) { + c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"}) + return } + callBackAddress := service.GetCallbackAddress() returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") @@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) { return } uri, params, err := client.Purchase(&epay.PurchaseArgs{ - Type: payType, + Type: req.PaymentMethod, ServiceTradeNo: tradeNo, Name: fmt.Sprintf("TUC%d", req.Amount), Money: strconv.FormatFloat(payMoney, 'f', 2, 64), diff --git a/setting/payment.go b/setting/payment.go index 4ffa4381..3fc0f14a 100644 --- a/setting/payment.go +++ b/setting/payment.go @@ -13,12 +13,12 @@ var PayMethods = []map[string]string{ { "name": "支付宝", "color": "rgba(var(--semi-blue-5), 1)", - "type": "zfb", + "type": "alipay", }, { "name": "微信", "color": "rgba(var(--semi-green-5), 1)", - "type": "wx", + "type": "wxpay", }, } From f5e80af0b339cf40d96b4a0e9f440c9903be8c39 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 20 Jun 2025 21:55:28 +0800 Subject: [PATCH 15/19] fix: update response handling in GeminiTextGenerationStreamHandler - Changed response handling from ObjectData to StringData for improved data processing. - Ensured proper error logging in case of response handling failure. --- relay/channel/gemini/relay-gemini-native.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index bf87eafd..cf7920dc 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -108,7 +108,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info } // 直接发送 GeminiChatResponse 响应 - err = helper.ObjectData(c, geminiResponse) + err = helper.StringData(c, data) if err != nil { common.LogError(c, err.Error()) } From a56d9ea98bb4d7adabd829b393e6585151060c91 Mon Sep 17 00:00:00 2001 From: creamlike1024 Date: Fri, 20 Jun 2025 23:01:10 +0800 Subject: [PATCH 16/19] =?UTF-8?q?fix:=20gemini=20=E5=8E=9F=E7=94=9F?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E6=B5=81=E6=A8=A1=E5=BC=8F=E4=B8=AD=E6=96=AD?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=E6=9C=AA=E8=AE=A1=E8=B4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/gemini/relay-gemini-native.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index cf7920dc..3a017a11 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -75,6 +75,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info helper.SetEventStreamHeaders(c) + // 本地统计的completion tokens + localCompletionTokens := 0 + helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse err := common.DecodeJsonStr(data, &geminiResponse) @@ -89,6 +92,12 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info if part.InlineData != nil && part.InlineData.MimeType != "" { imageCount++ } + // 本地统计completion tokens + textTokens, err := service.CountTextToken(part.Text, info.UpstreamModelName) + if err != nil { + common.LogError(c, "error counting text token: "+err.Error()) + } + localCompletionTokens += textTokens } } @@ -122,6 +131,12 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info } } + // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens + if usage.CompletionTokens == 0 { + usage.CompletionTokens = localCompletionTokens + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + // 计算最终使用量 // usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens From a9e5d99ea3ec6ed1591b7e4006e51de9a718c1ba Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 21 Jun 2025 00:54:40 +0800 Subject: [PATCH 17/19] refactor: token counter logic --- relay/audio_handler.go | 2 +- relay/channel/claude/relay-claude.go | 6 +-- relay/channel/cloudflare/relay_cloudflare.go | 6 +-- relay/channel/cohere/relay-cohere.go | 2 +- relay/channel/coze/relay-coze.go | 2 +- relay/channel/dify/relay-dify.go | 2 +- relay/channel/gemini/relay-gemini-native.go | 17 +++----- relay/channel/openai/relay-openai.go | 18 ++++---- relay/channel/openai/relay_responses.go | 2 +- relay/channel/palm/adaptor.go | 2 +- relay/channel/palm/relay-palm.go | 2 +- relay/channel/tencent/adaptor.go | 2 +- relay/channel/xai/text.go | 2 +- relay/embedding_handler.go | 2 +- relay/gemini_handler.go | 8 ++-- relay/relay-text.go | 6 +-- relay/rerank_handler.go | 8 ++-- relay/responses_handler.go | 8 ++-- service/token_counter.go | 44 ++++++++------------ service/usage_helpr.go | 6 +-- 20 files changed, 64 insertions(+), 83 deletions(-) diff --git a/relay/audio_handler.go b/relay/audio_handler.go index e55de042..96cf1019 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -66,7 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { promptTokens := 0 preConsumedTokens := common.PreConsumedQuota if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { - promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model) + promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model) if err != nil { return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index ba20adea..5e15d3a2 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -549,7 +549,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) { if requestMode == RequestModeCompletion { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) + claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) } else { if claudeInfo.Usage.PromptTokens == 0 { //上游出错 @@ -558,7 +558,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau if common.DebugEnabled { common.SysError("claude response usage is not complete, maybe upstream error") } - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) + claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } } @@ -618,7 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } } if requestMode == RequestModeCompletion { - completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) + completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError) } diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index a487429c..50d4928a 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela if err := scanner.Err(); err != nil { common.LogError(c, "error_scanning_stream_response: "+err.Error()) } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) err := helper.ObjectData(c, response) @@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) for _, choice := range response.Choices { responseText += choice.Message.StringContent() } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) response.Usage = *usage response.Id = helper.GetResponseID(c) jsonResponse, err := json.Marshal(response) @@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) + usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index 8a044bf2..29064242 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } }) if usage.PromptTokens == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } return nil, usage } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 6db40213..e9719cb9 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo if usage.TotalTokens == 0 { usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) + usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 93e3e8d6..b3ae5927 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -250,7 +250,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re } if usage.TotalTokens == 0 { usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) + usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } usage.CompletionTokens += nodeToken diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 3a017a11..1a497b9f 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -9,6 +9,7 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "strings" "github.com/gin-gonic/gin" ) @@ -75,8 +76,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info helper.SetEventStreamHeaders(c) - // 本地统计的completion tokens - localCompletionTokens := 0 + responseText := strings.Builder{} helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse @@ -92,12 +92,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info if part.InlineData != nil && part.InlineData.MimeType != "" { imageCount++ } - // 本地统计completion tokens - textTokens, err := service.CountTextToken(part.Text, info.UpstreamModelName) - if err != nil { - common.LogError(c, "error counting text token: "+err.Error()) + if part.Text != "" { + responseText.WriteString(part.Text) } - localCompletionTokens += textTokens } } @@ -133,13 +130,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens if usage.CompletionTokens == 0 { - usage.CompletionTokens = localCompletionTokens - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens) } - // 计算最终使用量 - // usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens - // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为 //helper.Done(c) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 4dc0fc60..71590cd6 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -8,7 +8,6 @@ import ( "math" "mime/multipart" "net/http" - "path/filepath" "one-api/common" "one-api/constant" "one-api/dto" @@ -16,6 +15,7 @@ import ( "one-api/relay/helper" "one-api/service" "os" + "path/filepath" "strings" "github.com/bytedance/gopkg/util/gopool" @@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } if !containStreamUsage { - usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage.CompletionTokens += toolCount * 7 } else { if info.ChannelType == common.ChannelTypeDeepSeek { @@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI StatusCode: resp.StatusCode, }, nil } - + forceFormat := false if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { forceFormat = forceFmt @@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) + ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) completionTokens += ctkm } simpleResponse.Usage = dto.Usage{ @@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { // the status code has been judged before, if there is a body reading failure, // it should be regarded as a non-recoverable error, so it should not return err for external retry. - // Analogous to nginx's load balancing, it will only retry if it can't be requested or - // if the upstream returns a specific status code, once the upstream has already written the header, - // the subsequent failure of the response body should be regarded as a non-recoverable error, + // Analogous to nginx's load balancing, it will only retry if it can't be requested or + // if the upstream returns a specific status code, once the upstream has already written the header, + // the subsequent failure of the response body should be regarded as a non-recoverable error, // and can be terminated directly. defer resp.Body.Close() usage := &dto.Usage{} @@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) { if err = c.ShouldBind(&reqBody); err != nil { return 0, errors.WithStack(err) } - ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名 + ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名 reqFp, err := reqBody.File.Open() if err != nil { return 0, errors.WithStack(err) } - defer reqFp.Close() + defer reqFp.Close() tmpFp, err := os.CreateTemp("", "audio-*"+ext) if err != nil { diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 1d1e060e..da9382c3 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc tempStr := responseTextBuilder.String() if len(tempStr) > 0 { // 非正常结束,使用输出文本的 token 数量 - completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) + completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName) usage.CompletionTokens = completionTokens } } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 3a06e7ee..aee4a307 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 0c6f8641..9d3dbd67 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model) + completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 44718a25..7ea3aae7 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = tencentStreamHandler(c, resp) - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = tencentHandler(c, resp) } diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index e019c2dc..408160fb 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel }) if !containStreamUsage { - usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage.CompletionTokens += toolCount * 7 } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index fbf4990a..849c70da 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -15,7 +15,7 @@ import ( ) func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { - token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) + token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) return token } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index fa41cc7b..14d58cc5 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, return sensitiveWords, err } -func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) { +func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int { // 计算输入 token 数量 var inputTexts []string for _, content := range req.Contents { @@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay } inputText := strings.Join(inputTexts, "\n") - inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName) + inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName) info.PromptTokens = inputTokens - return inputTokens, err + return inputTokens } func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { @@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { promptTokens := value.(int) relayInfo.SetPromptTokens(promptTokens) } else { - promptTokens, err := getGeminiInputTokens(req, relayInfo) + promptTokens := getGeminiInputTokens(req, relayInfo) if err != nil { return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) } diff --git a/relay/relay-text.go b/relay/relay-text.go index bf5a0259..db8d0d3b 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -251,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re case relayconstant.RelayModeChatCompletions: promptTokens, err = service.CountTokenChatRequest(info, *textRequest) case relayconstant.RelayModeCompletions: - promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model) + promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model) case relayconstant.RelayModeModerations: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) + promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) case relayconstant.RelayModeEmbeddings: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) + promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) default: err = errors.New("unknown relay mode") promptTokens = 0 diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 4d02c84f..319811b8 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -14,12 +14,10 @@ import ( ) func getRerankPromptToken(rerankRequest dto.RerankRequest) int { - token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model) + token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model) for _, document := range rerankRequest.Documents { - tkm, err := service.CountTokenInput(document, rerankRequest.Model) - if err == nil { - token += tkm - } + tkm := service.CountTokenInput(document, rerankRequest.Model) + token += tkm } return token } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 8e8a3451..9d4adf49 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom return sensitiveWords, err } -func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) { - inputTokens, err := service.CountTokenInput(req.Input, req.Model) +func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int { + inputTokens := service.CountTokenInput(req.Input, req.Model) info.PromptTokens = inputTokens - return inputTokens, err + return inputTokens } func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { @@ -72,7 +72,7 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) promptTokens := value.(int) relayInfo.SetPromptTokens(promptTokens) } else { - promptTokens, err := getInputTokens(req, relayInfo) + promptTokens := getInputTokens(req, relayInfo) if err != nil { return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) } diff --git a/service/token_counter.go b/service/token_counter.go index 82de0a05..53c6c2fa 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA countStr += fmt.Sprintf("%v", tool.Function.Parameters) } } - toolTokens, err := CountTokenInput(countStr, request.Model) + toolTokens := CountTokenInput(countStr, request.Model) if err != nil { return 0, err } @@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro // Count tokens in system message if request.System != "" { - systemTokens, err := CountTokenInput(request.System, model) + systemTokens := CountTokenInput(request.System, model) if err != nil { return 0, err } @@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, switch request.Type { case dto.RealtimeEventTypeSessionUpdate: if request.Session != nil { - msgTokens, err := CountTextToken(request.Session.Instructions, model) - if err != nil { - return 0, 0, err - } + msgTokens := CountTextToken(request.Session.Instructions, model) textToken += msgTokens } case dto.RealtimeEventResponseAudioDelta: @@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, audioToken += atk case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta: // count text token - tkm, err := CountTextToken(request.Delta, model) - if err != nil { - return 0, 0, fmt.Errorf("error counting text token: %v", err) - } + tkm := CountTextToken(request.Delta, model) textToken += tkm case dto.RealtimeEventInputAudioBufferAppend: // count audio token @@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, case "message": for _, content := range request.Item.Content { if content.Type == "input_text" { - tokens, err := CountTextToken(content.Text, model) - if err != nil { - return 0, 0, err - } + tokens := CountTextToken(content.Text, model) textToken += tokens } } @@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, if !info.IsFirstRequest { if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 { for _, tool := range info.RealtimeTools { - toolTokens, err := CountTokenInput(tool, model) - if err != nil { - return 0, 0, err - } + toolTokens := CountTokenInput(tool, model) textToken += 8 textToken += toolTokens } @@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod return tokenNum, nil } -func CountTokenInput(input any, model string) (int, error) { +func CountTokenInput(input any, model string) int { switch v := input.(type) { case string: return CountTextToken(v, model) @@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) { func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int { tokens := 0 for _, message := range messages { - tkm, _ := CountTokenInput(message.Delta.GetContentString(), model) + tkm := CountTokenInput(message.Delta.GetContentString(), model) tokens += tkm if message.Delta.ToolCalls != nil { for _, tool := range message.Delta.ToolCalls { - tkm, _ := CountTokenInput(tool.Function.Name, model) + tkm := CountTokenInput(tool.Function.Name, model) tokens += tkm - tkm, _ = CountTokenInput(tool.Function.Arguments, model) + tkm = CountTokenInput(tool.Function.Arguments, model) tokens += tkm } } @@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, return tokens } -func CountTTSToken(text string, model string) (int, error) { +func CountTTSToken(text string, model string) int { if strings.HasPrefix(model, "tts") { - return utf8.RuneCountInString(text), nil + return utf8.RuneCountInString(text) } else { return CountTextToken(text, model) } @@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) //} // CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 -func CountTextToken(text string, model string) (int, error) { - var err error +func CountTextToken(text string, model string) int { + if text == "" { + return 0 + } tokenEncoder := getTokenEncoder(model) - return getTokenNum(tokenEncoder, text), err + return getTokenNum(tokenEncoder, text) } diff --git a/service/usage_helpr.go b/service/usage_helpr.go index c52e1e15..ca9c0830 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -16,13 +16,13 @@ import ( // return 0, errors.New("unknown relay mode") //} -func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { +func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage { usage := &dto.Usage{} usage.PromptTokens = promptTokens - ctkm, err := CountTextToken(responseText, modeName) + ctkm := CountTextToken(responseText, modeName) usage.CompletionTokens = ctkm usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return usage, err + return usage } func ValidUsage(usage *dto.Usage) bool { From 0708452939d610a73fd8530f6b8ddf48c2cf2ff2 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 21 Jun 2025 01:08:15 +0800 Subject: [PATCH 18/19] fix: improve usage calculation in GeminiTextGenerationStreamHandler --- relay/channel/gemini/relay-gemini-native.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 1a497b9f..39757cef 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -130,7 +130,13 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens if usage.CompletionTokens == 0 { - usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens) + str := responseText.String() + if len(str) > 0 { + usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + // 空补全,不需要使用量 + usage = &dto.Usage{} + } } // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为 From 7afd3f97eec60111f18d231dcc9f9a6bc20045f5 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 21 Jun 2025 01:16:54 +0800 Subject: [PATCH 19/19] fix: remove unnecessary error handling in token counting functions --- relay/audio_handler.go | 3 --- relay/channel/claude/relay-claude.go | 3 --- relay/channel/coze/relay-coze.go | 12 +++++------- relay/channel/dify/relay-dify.go | 9 +-------- relay/responses_handler.go | 3 --- 5 files changed, 6 insertions(+), 24 deletions(-) diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 96cf1019..c1ce1a02 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -67,9 +67,6 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { preConsumedTokens := common.PreConsumedQuota if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model) - if err != nil { - return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) - } preConsumedTokens = promptTokens relayInfo.PromptTokens = promptTokens } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 5e15d3a2..406ebc8a 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -619,9 +619,6 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } if requestMode == RequestModeCompletion { completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) - if err != nil { - return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError) - } claudeInfo.Usage.PromptTokens = info.PromptTokens claudeInfo.Usage.CompletionTokens = completionTokens claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index e9719cb9..ac76476f 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo var currentEvent string var currentData string - var usage dto.Usage + var usage = &dto.Usage{} for scanner.Scan() { line := scanner.Text() @@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo if line == "" { if currentEvent != "" && currentData != "" { // handle last event - handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info) currentEvent = "" currentData = "" } @@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo // Last event if currentEvent != "" && currentData != "" { - handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info) } if err := scanner.Err(); err != nil { @@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo helper.Done(c) if usage.TotalTokens == 0 { - usage.PromptTokens = info.PromptTokens - usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count")) } - return nil, &usage + return nil, usage } func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index b3ae5927..115aed1b 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re return true }) helper.Done(c) - err := resp.Body.Close() - if err != nil { - // return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - common.SysError("close_response_body_failed: " + err.Error()) - } if usage.TotalTokens == 0 { - usage.PromptTokens = info.PromptTokens - usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } usage.CompletionTokens += nodeToken return nil, usage diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 9d4adf49..e744e354 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -73,9 +73,6 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) relayInfo.SetPromptTokens(promptTokens) } else { promptTokens := getInputTokens(req, relayInfo) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) - } c.Set("prompt_tokens", promptTokens) }