diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 28932cc5..974ad0f8 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -407,29 +407,53 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts [] return map[int64]*service.AccountLoadInfo{}, nil } - args := []any{c.slotTTLSeconds} - for _, acc := range accounts { - args = append(args, acc.ID, acc.MaxConcurrency) - } - - result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + // 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster(Lua 内动态拼 key 会 CROSSSLOT)。 + // 每个账号执行 3 个命令:ZREMRANGEBYSCORE(清理过期)、ZCARD(并发数)、GET(等待数)。 + now, err := c.rdb.Time(ctx).Result() if err != nil { - return nil, err + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + + type accountCmds struct { + id int64 + maxConcurrency int + zcardCmd *redis.IntCmd + getCmd *redis.StringCmd + } + cmds := make([]accountCmds, 0, len(accounts)) + for _, acc := range accounts { + slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10) + waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + ac := accountCmds{ + id: acc.ID, + maxConcurrency: acc.MaxConcurrency, + zcardCmd: pipe.ZCard(ctx, slotKey), + getCmd: pipe.Get(ctx, waitKey), + } + cmds = append(cmds, ac) } - loadMap := make(map[int64]*service.AccountLoadInfo) - for i := 0; i < len(result); i += 4 { - if i+3 >= len(result) { - break + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, ac := range cmds { + currentConcurrency := int(ac.zcardCmd.Val()) + waitingCount := 0 + if v, err := ac.getCmd.Int(); err == nil { + waitingCount = v } - - accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[accountID] = &service.AccountLoadInfo{ - AccountID: accountID, + loadRate := 0 + if ac.maxConcurrency > 0 { + loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency + } + loadMap[ac.id] = &service.AccountLoadInfo{ + AccountID: ac.id, CurrentConcurrency: currentConcurrency, WaitingCount: waitingCount, LoadRate: loadRate, @@ -444,29 +468,52 @@ func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []servic return map[int64]*service.UserLoadInfo{}, nil } - args := []any{c.slotTTLSeconds} - for _, u := range users { - args = append(args, u.ID, u.MaxConcurrency) - } - - result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + // 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。 + now, err := c.rdb.Time(ctx).Result() if err != nil { - return nil, err + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + + type userCmds struct { + id int64 + maxConcurrency int + zcardCmd *redis.IntCmd + getCmd *redis.StringCmd + } + cmds := make([]userCmds, 0, len(users)) + for _, u := range users { + slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10) + waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + uc := userCmds{ + id: u.ID, + maxConcurrency: u.MaxConcurrency, + zcardCmd: pipe.ZCard(ctx, slotKey), + getCmd: pipe.Get(ctx, waitKey), + } + cmds = append(cmds, uc) } - loadMap := make(map[int64]*service.UserLoadInfo) - for i := 0; i < len(result); i += 4 { - if i+3 >= len(result) { - break + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + loadMap := make(map[int64]*service.UserLoadInfo, len(users)) + for _, uc := range cmds { + currentConcurrency := int(uc.zcardCmd.Val()) + waitingCount := 0 + if v, err := uc.getCmd.Int(); err == nil { + waitingCount = v } - - userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[userID] = &service.UserLoadInfo{ - UserID: userID, + loadRate := 0 + if uc.maxConcurrency > 0 { + loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency + } + loadMap[uc.id] = &service.UserLoadInfo{ + UserID: uc.id, CurrentConcurrency: currentConcurrency, WaitingCount: waitingCount, LoadRate: loadRate, diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 5ff2c866..e6660399 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -297,7 +297,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage } outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier) if err != nil { - return inRangeCost, nil // 出错时返回范围内成本 + return inRangeCost, fmt.Errorf("out-range cost: %w", err) } // 合并成本 diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index 68ebd90a..d7ff297c 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log" "mime" "net" "net/http" @@ -210,9 +211,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() { stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs) if storeErr != nil { - return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream) + // 存储失败时降级使用原始 URL,不中断用户请求 + log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr) + } else { + finalURLs = s.normalizeSoraMediaURLs(stored) } - finalURLs = s.normalizeSoraMediaURLs(stored) } content := buildSoraContent(mediaType, finalURLs) diff --git a/backend/internal/service/sora_media_cleanup_service.go b/backend/internal/service/sora_media_cleanup_service.go index 7de0f1c4..d7d53c2a 100644 --- a/backend/internal/service/sora_media_cleanup_service.go +++ b/backend/internal/service/sora_media_cleanup_service.go @@ -85,6 +85,9 @@ func (s *SoraMediaCleanupService) Stop() { } func (s *SoraMediaCleanupService) runCleanup() { + if s.cfg == nil || s.storage == nil { + return + } retention := s.cfg.Sora.Storage.Cleanup.RetentionDays if retention <= 0 { log.Printf("[SoraCleanup] skipped (retention_days=%d)", retention) diff --git a/backend/internal/service/subscription_maintenance_queue.go b/backend/internal/service/subscription_maintenance_queue.go index 52ad6472..35bf18f3 100644 --- a/backend/internal/service/subscription_maintenance_queue.go +++ b/backend/internal/service/subscription_maintenance_queue.go @@ -6,12 +6,14 @@ import ( "sync" ) -// SubscriptionMaintenanceQueue 提供“有界队列 + 固定 worker”的后台执行器。 +// SubscriptionMaintenanceQueue 提供"有界队列 + 固定 worker"的后台执行器。 // 用于从请求热路径触发维护动作时,避免无限 goroutine 膨胀。 type SubscriptionMaintenanceQueue struct { - queue chan func() - wg sync.WaitGroup - stop sync.Once + queue chan func() + wg sync.WaitGroup + stop sync.Once + mu sync.RWMutex // 保护 closed 标志与 channel 操作的原子性 + closed bool } func NewSubscriptionMaintenanceQueue(workerCount, queueSize int) *SubscriptionMaintenanceQueue { @@ -48,6 +50,7 @@ func NewSubscriptionMaintenanceQueue(workerCount, queueSize int) *SubscriptionMa // TryEnqueue 尝试将任务入队。 // 当队列已满时返回 error(调用方应该选择跳过并记录告警/限频日志)。 +// 当队列已关闭时返回 error,不会 panic。 func (q *SubscriptionMaintenanceQueue) TryEnqueue(task func()) error { if q == nil { return fmt.Errorf("maintenance queue is nil") @@ -56,6 +59,13 @@ func (q *SubscriptionMaintenanceQueue) TryEnqueue(task func()) error { return fmt.Errorf("maintenance task is nil") } + q.mu.RLock() + defer q.mu.RUnlock() + + if q.closed { + return fmt.Errorf("maintenance queue stopped") + } + select { case q.queue <- task: return nil @@ -69,7 +79,10 @@ func (q *SubscriptionMaintenanceQueue) Stop() { return } q.stop.Do(func() { + q.mu.Lock() + q.closed = true close(q.queue) + q.mu.Unlock() q.wg.Wait() }) } diff --git a/backend/internal/service/timing_wheel_service.go b/backend/internal/service/timing_wheel_service.go index 5a2dea75..a08c80a8 100644 --- a/backend/internal/service/timing_wheel_service.go +++ b/backend/internal/service/timing_wheel_service.go @@ -47,7 +47,9 @@ func (s *TimingWheelService) Stop() { // Schedule schedules a one-time task func (s *TimingWheelService) Schedule(name string, delay time.Duration, fn func()) { - _ = s.tw.SetTimer(name, fn, delay) + if err := s.tw.SetTimer(name, fn, delay); err != nil { + log.Printf("[TimingWheel] SetTimer failed for %q: %v", name, err) + } } // ScheduleRecurring schedules a recurring task @@ -55,9 +57,13 @@ func (s *TimingWheelService) ScheduleRecurring(name string, interval time.Durati var schedule func() schedule = func() { fn() - _ = s.tw.SetTimer(name, schedule, interval) + if err := s.tw.SetTimer(name, schedule, interval); err != nil { + log.Printf("[TimingWheel] recurring SetTimer failed for %q: %v", name, err) + } + } + if err := s.tw.SetTimer(name, schedule, interval); err != nil { + log.Printf("[TimingWheel] initial SetTimer failed for %q: %v", name, err) } - _ = s.tw.SetTimer(name, schedule, interval) } // Cancel cancels a scheduled task