From 6f39c0285706c0eda5836e2916a3c097c060be36 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 21 Feb 2026 23:47:55 +0800 Subject: [PATCH] refactor(relay): improve channel locking and retry logic in RelayTask - Enhanced the RelayTask function to utilize a locked channel when available, allowing for better reuse during retries. - Updated error handling to ensure proper context setup for the selected channel. - Clarified comments in ResolveOriginTask regarding channel locking and retry behavior. - Introduced a new field in TaskRelayInfo to store the locked channel object, improving type safety and reducing import cycles. --- controller/relay.go | 23 ++++++++++++++++++----- relay/common/relay_info.go | 5 +++++ relay/relay_task.go | 26 +++++++++++++------------- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 1477df8f..6951974c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -497,11 +497,24 @@ func RelayTask(c *gin.Context) { } for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() { - channel, channelErr := getChannel(c, relayInfo, retryParam) - if channelErr != nil { - logger.LogError(c, channelErr.Error()) - taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) - break + var channel *model.Channel + + if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil { + channel = lockedCh + if retryParam.GetRetry() > 0 { + if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil { + taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError) + break + } + } + } else { + var channelErr *types.NewAPIError + channel, channelErr = getChannel(c, relayInfo, retryParam) + if channelErr != nil { + logger.LogError(c, channelErr.Error()) + taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError) + break + } } addUsedChannel(c, channel.Id) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index b6882681..541f1b9f 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -619,6 +619,11 @@ type TaskRelayInfo struct { PublicTaskID string ConsumeQuota bool + + // LockedChannel holds the full channel object when the request is bound to + // a specific channel (e.g., remix on origin task's channel). Stored as any + // to avoid an import cycle with model; callers type-assert to *model.Channel. + LockedChannel any } type TaskSubmitReq struct { diff --git a/relay/relay_task.go b/relay/relay_task.go index cc4d0e45..8d0e61d7 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -32,8 +32,9 @@ type TaskSubmitResult struct { } // ResolveOriginTask 处理基于已有任务的提交(remix / continuation): -// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过 -// specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。 +// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道 +// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key), +// 以及提取 OtherRatios(时长、分辨率)。 // 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。 func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { // 检测 remix action @@ -77,15 +78,17 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr } } - // 锁定到原始任务的渠道(如果与当前选中的不同) + // 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key) + ch, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + } + if ch.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) + } + info.LockedChannel = ch + if originTask.ChannelId != info.ChannelId { - ch, err := model.GetChannelById(originTask.ChannelId, true) - if err != nil { - return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) - } - if ch.Status != common.ChannelStatusEnabled { - return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) - } key, _, newAPIError := ch.GetNextEnabledKey() if newAPIError != nil { return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) @@ -101,9 +104,6 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr info.ApiKey = key } - // 渠道已锁定到原始任务 → 禁止重试切换到其他渠道 - c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId)) - // 提取 remix 参数(时长、分辨率 → OtherRatios) if info.Action == constant.TaskActionRemix { if originTask.PrivateData.BillingContext != nil {