fix(backend): 修复代码审核发现的 8 个确认问题
- P0-1: subscription_maintenance_queue 使用 RWMutex 防止 channel close/send 竞态
- P0-2: billing_service CalculateCostWithLongContext 修复被吞没的 out-range 错误
- P1-1: timing_wheel_service Schedule/ScheduleRecurring 添加 SetTimer 错误日志
- P1-2: sora_gateway_service StoreFromURLs 失败时降级使用原始 URL
- P1-3: concurrency_cache 用 Pipeline 替代 Lua 脚本兼容 Redis Cluster
- P1-6: sora_media_cleanup_service runCleanup 添加 nil cfg/storage 防护

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -407,29 +407,53 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []
|
|||||||
return map[int64]*service.AccountLoadInfo{}, nil
|
return map[int64]*service.AccountLoadInfo{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []any{c.slotTTLSeconds}
|
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster(Lua 内动态拼 key 会 CROSSSLOT)。
|
||||||
for _, acc := range accounts {
|
// 每个账号执行 3 个命令:ZREMRANGEBYSCORE(清理过期)、ZCARD(并发数)、GET(等待数)。
|
||||||
args = append(args, acc.ID, acc.MaxConcurrency)
|
now, err := c.rdb.Time(ctx).Result()
|
||||||
}
|
|
||||||
|
|
||||||
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("redis TIME: %w", err)
|
||||||
|
}
|
||||||
|
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
||||||
|
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
|
||||||
|
type accountCmds struct {
|
||||||
|
id int64
|
||||||
|
maxConcurrency int
|
||||||
|
zcardCmd *redis.IntCmd
|
||||||
|
getCmd *redis.StringCmd
|
||||||
|
}
|
||||||
|
cmds := make([]accountCmds, 0, len(accounts))
|
||||||
|
for _, acc := range accounts {
|
||||||
|
slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10)
|
||||||
|
waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10)
|
||||||
|
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
||||||
|
ac := accountCmds{
|
||||||
|
id: acc.ID,
|
||||||
|
maxConcurrency: acc.MaxConcurrency,
|
||||||
|
zcardCmd: pipe.ZCard(ctx, slotKey),
|
||||||
|
getCmd: pipe.Get(ctx, waitKey),
|
||||||
|
}
|
||||||
|
cmds = append(cmds, ac)
|
||||||
}
|
}
|
||||||
|
|
||||||
loadMap := make(map[int64]*service.AccountLoadInfo)
|
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
||||||
for i := 0; i < len(result); i += 4 {
|
return nil, fmt.Errorf("pipeline exec: %w", err)
|
||||||
if i+3 >= len(result) {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
||||||
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
for _, ac := range cmds {
|
||||||
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
currentConcurrency := int(ac.zcardCmd.Val())
|
||||||
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
waitingCount := 0
|
||||||
|
if v, err := ac.getCmd.Int(); err == nil {
|
||||||
loadMap[accountID] = &service.AccountLoadInfo{
|
waitingCount = v
|
||||||
AccountID: accountID,
|
}
|
||||||
|
loadRate := 0
|
||||||
|
if ac.maxConcurrency > 0 {
|
||||||
|
loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency
|
||||||
|
}
|
||||||
|
loadMap[ac.id] = &service.AccountLoadInfo{
|
||||||
|
AccountID: ac.id,
|
||||||
CurrentConcurrency: currentConcurrency,
|
CurrentConcurrency: currentConcurrency,
|
||||||
WaitingCount: waitingCount,
|
WaitingCount: waitingCount,
|
||||||
LoadRate: loadRate,
|
LoadRate: loadRate,
|
||||||
@@ -444,29 +468,52 @@ func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []servic
|
|||||||
return map[int64]*service.UserLoadInfo{}, nil
|
return map[int64]*service.UserLoadInfo{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []any{c.slotTTLSeconds}
|
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。
|
||||||
for _, u := range users {
|
now, err := c.rdb.Time(ctx).Result()
|
||||||
args = append(args, u.ID, u.MaxConcurrency)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("redis TIME: %w", err)
|
||||||
|
}
|
||||||
|
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
||||||
|
|
||||||
|
pipe := c.rdb.Pipeline()
|
||||||
|
|
||||||
|
type userCmds struct {
|
||||||
|
id int64
|
||||||
|
maxConcurrency int
|
||||||
|
zcardCmd *redis.IntCmd
|
||||||
|
getCmd *redis.StringCmd
|
||||||
|
}
|
||||||
|
cmds := make([]userCmds, 0, len(users))
|
||||||
|
for _, u := range users {
|
||||||
|
slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10)
|
||||||
|
waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10)
|
||||||
|
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
||||||
|
uc := userCmds{
|
||||||
|
id: u.ID,
|
||||||
|
maxConcurrency: u.MaxConcurrency,
|
||||||
|
zcardCmd: pipe.ZCard(ctx, slotKey),
|
||||||
|
getCmd: pipe.Get(ctx, waitKey),
|
||||||
|
}
|
||||||
|
cmds = append(cmds, uc)
|
||||||
}
|
}
|
||||||
|
|
||||||
loadMap := make(map[int64]*service.UserLoadInfo)
|
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
||||||
for i := 0; i < len(result); i += 4 {
|
return nil, fmt.Errorf("pipeline exec: %w", err)
|
||||||
if i+3 >= len(result) {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
loadMap := make(map[int64]*service.UserLoadInfo, len(users))
|
||||||
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
for _, uc := range cmds {
|
||||||
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
currentConcurrency := int(uc.zcardCmd.Val())
|
||||||
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
waitingCount := 0
|
||||||
|
if v, err := uc.getCmd.Int(); err == nil {
|
||||||
loadMap[userID] = &service.UserLoadInfo{
|
waitingCount = v
|
||||||
UserID: userID,
|
}
|
||||||
|
loadRate := 0
|
||||||
|
if uc.maxConcurrency > 0 {
|
||||||
|
loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency
|
||||||
|
}
|
||||||
|
loadMap[uc.id] = &service.UserLoadInfo{
|
||||||
|
UserID: uc.id,
|
||||||
CurrentConcurrency: currentConcurrency,
|
CurrentConcurrency: currentConcurrency,
|
||||||
WaitingCount: waitingCount,
|
WaitingCount: waitingCount,
|
||||||
LoadRate: loadRate,
|
LoadRate: loadRate,
|
||||||
|
|||||||
@@ -297,7 +297,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
|
|||||||
}
|
}
|
||||||
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
|
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return inRangeCost, nil // 出错时返回范围内成本
|
return inRangeCost, fmt.Errorf("out-range cost: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 合并成本
|
// 合并成本
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"mime"
|
"mime"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -210,10 +211,12 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
|
|||||||
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
|
||||||
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
|
||||||
if storeErr != nil {
|
if storeErr != nil {
|
||||||
return nil, s.handleSoraRequestError(ctx, account, storeErr, reqModel, c, clientStream)
|
// 存储失败时降级使用原始 URL,不中断用户请求
|
||||||
}
|
log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr)
|
||||||
|
} else {
|
||||||
finalURLs = s.normalizeSoraMediaURLs(stored)
|
finalURLs = s.normalizeSoraMediaURLs(stored)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
content := buildSoraContent(mediaType, finalURLs)
|
content := buildSoraContent(mediaType, finalURLs)
|
||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
|
|||||||
@@ -85,6 +85,9 @@ func (s *SoraMediaCleanupService) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SoraMediaCleanupService) runCleanup() {
|
func (s *SoraMediaCleanupService) runCleanup() {
|
||||||
|
if s.cfg == nil || s.storage == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
retention := s.cfg.Sora.Storage.Cleanup.RetentionDays
|
retention := s.cfg.Sora.Storage.Cleanup.RetentionDays
|
||||||
if retention <= 0 {
|
if retention <= 0 {
|
||||||
log.Printf("[SoraCleanup] skipped (retention_days=%d)", retention)
|
log.Printf("[SoraCleanup] skipped (retention_days=%d)", retention)
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SubscriptionMaintenanceQueue 提供“有界队列 + 固定 worker”的后台执行器。
|
// SubscriptionMaintenanceQueue 提供"有界队列 + 固定 worker"的后台执行器。
|
||||||
// 用于从请求热路径触发维护动作时,避免无限 goroutine 膨胀。
|
// 用于从请求热路径触发维护动作时,避免无限 goroutine 膨胀。
|
||||||
type SubscriptionMaintenanceQueue struct {
|
type SubscriptionMaintenanceQueue struct {
|
||||||
queue chan func()
|
queue chan func()
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
stop sync.Once
|
stop sync.Once
|
||||||
|
mu sync.RWMutex // 保护 closed 标志与 channel 操作的原子性
|
||||||
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSubscriptionMaintenanceQueue(workerCount, queueSize int) *SubscriptionMaintenanceQueue {
|
func NewSubscriptionMaintenanceQueue(workerCount, queueSize int) *SubscriptionMaintenanceQueue {
|
||||||
@@ -48,6 +50,7 @@ func NewSubscriptionMaintenanceQueue(workerCount, queueSize int) *SubscriptionMa
|
|||||||
|
|
||||||
// TryEnqueue 尝试将任务入队。
|
// TryEnqueue 尝试将任务入队。
|
||||||
// 当队列已满时返回 error(调用方应该选择跳过并记录告警/限频日志)。
|
// 当队列已满时返回 error(调用方应该选择跳过并记录告警/限频日志)。
|
||||||
|
// 当队列已关闭时返回 error,不会 panic。
|
||||||
func (q *SubscriptionMaintenanceQueue) TryEnqueue(task func()) error {
|
func (q *SubscriptionMaintenanceQueue) TryEnqueue(task func()) error {
|
||||||
if q == nil {
|
if q == nil {
|
||||||
return fmt.Errorf("maintenance queue is nil")
|
return fmt.Errorf("maintenance queue is nil")
|
||||||
@@ -56,6 +59,13 @@ func (q *SubscriptionMaintenanceQueue) TryEnqueue(task func()) error {
|
|||||||
return fmt.Errorf("maintenance task is nil")
|
return fmt.Errorf("maintenance task is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
q.mu.RLock()
|
||||||
|
defer q.mu.RUnlock()
|
||||||
|
|
||||||
|
if q.closed {
|
||||||
|
return fmt.Errorf("maintenance queue stopped")
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case q.queue <- task:
|
case q.queue <- task:
|
||||||
return nil
|
return nil
|
||||||
@@ -69,7 +79,10 @@ func (q *SubscriptionMaintenanceQueue) Stop() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
q.stop.Do(func() {
|
q.stop.Do(func() {
|
||||||
|
q.mu.Lock()
|
||||||
|
q.closed = true
|
||||||
close(q.queue)
|
close(q.queue)
|
||||||
|
q.mu.Unlock()
|
||||||
q.wg.Wait()
|
q.wg.Wait()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,7 +47,9 @@ func (s *TimingWheelService) Stop() {
|
|||||||
|
|
||||||
// Schedule schedules a one-time task
|
// Schedule schedules a one-time task
|
||||||
func (s *TimingWheelService) Schedule(name string, delay time.Duration, fn func()) {
|
func (s *TimingWheelService) Schedule(name string, delay time.Duration, fn func()) {
|
||||||
_ = s.tw.SetTimer(name, fn, delay)
|
if err := s.tw.SetTimer(name, fn, delay); err != nil {
|
||||||
|
log.Printf("[TimingWheel] SetTimer failed for %q: %v", name, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScheduleRecurring schedules a recurring task
|
// ScheduleRecurring schedules a recurring task
|
||||||
@@ -55,9 +57,13 @@ func (s *TimingWheelService) ScheduleRecurring(name string, interval time.Durati
|
|||||||
var schedule func()
|
var schedule func()
|
||||||
schedule = func() {
|
schedule = func() {
|
||||||
fn()
|
fn()
|
||||||
_ = s.tw.SetTimer(name, schedule, interval)
|
if err := s.tw.SetTimer(name, schedule, interval); err != nil {
|
||||||
|
log.Printf("[TimingWheel] recurring SetTimer failed for %q: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.tw.SetTimer(name, schedule, interval); err != nil {
|
||||||
|
log.Printf("[TimingWheel] initial SetTimer failed for %q: %v", name, err)
|
||||||
}
|
}
|
||||||
_ = s.tw.SetTimer(name, schedule, interval)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cancel cancels a scheduled task
|
// Cancel cancels a scheduled task
|
||||||
|
|||||||
Reference in New Issue
Block a user