From fe71ee57b3b78d36f20dc833f7e4c7f97a3728be Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Fri, 16 Jan 2026 15:25:33 +0800 Subject: [PATCH] =?UTF-8?q?fix(=E5=AE=9A=E6=97=B6=E8=BD=AE):=20=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E5=A4=B1=E8=B4=A5=E8=BF=94=E5=9B=9E=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E5=B9=B6=E8=A1=A5=E5=85=85=E5=8D=95=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - NewTimingWheelService 改为返回 error,避免 panic - ProvideTimingWheelService 透传 error 并更新 wire 生成代码 - 补充定时任务调度/取消/周期任务相关单元测试 --- backend/cmd/server/wire_gen.go | 5 +- .../internal/service/timing_wheel_service.go | 11 +- .../service/timing_wheel_service_test.go | 146 ++++++++++++++++++ backend/internal/service/wire.go | 9 +- backend/internal/service/wire_test.go | 37 +++++ 5 files changed, 200 insertions(+), 8 deletions(-) create mode 100644 backend/internal/service/timing_wheel_service_test.go create mode 100644 backend/internal/service/wire_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 289a14bd..27404b02 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -78,7 +78,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db) dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig) dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig) - timingWheelService := service.ProvideTimingWheelService() + timingWheelService, err := service.ProvideTimingWheelService() + if err != nil { + return nil, err + } dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig) dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) accountRepository := repository.NewAccountRepository(client, db) diff --git a/backend/internal/service/timing_wheel_service.go b/backend/internal/service/timing_wheel_service.go index c4e64e33..5a2dea75 100644 --- a/backend/internal/service/timing_wheel_service.go +++ b/backend/internal/service/timing_wheel_service.go @@ -1,6 +1,7 @@ package service import ( + "fmt" "log" "sync" "time" @@ -8,6 +9,8 @@ import ( "github.com/zeromicro/go-zero/core/collection" ) +var newTimingWheel = collection.NewTimingWheel + // TimingWheelService wraps go-zero's TimingWheel for task scheduling type TimingWheelService struct { tw *collection.TimingWheel @@ -15,18 +18,18 @@ type TimingWheelService struct { } // NewTimingWheelService creates a new TimingWheelService instance -func NewTimingWheelService() *TimingWheelService { +func NewTimingWheelService() (*TimingWheelService, error) { // 1 second tick, 3600 slots = supports up to 1 hour delay // execute function: runs func() type tasks - tw, err := collection.NewTimingWheel(1*time.Second, 3600, func(key, value any) { + tw, err := newTimingWheel(1*time.Second, 3600, func(key, value any) { if fn, ok := value.(func()); ok { fn() } }) if err != nil { - panic(err) + return nil, fmt.Errorf("创建 timing wheel 失败: %w", err) } - return &TimingWheelService{tw: tw} + return &TimingWheelService{tw: tw}, nil } // Start starts the timing wheel diff --git a/backend/internal/service/timing_wheel_service_test.go b/backend/internal/service/timing_wheel_service_test.go new file mode 100644 index 00000000..cd0bffb7 --- /dev/null +++ b/backend/internal/service/timing_wheel_service_test.go @@ -0,0 +1,146 @@ +package service + +import ( + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/zeromicro/go-zero/core/collection" +) + +func TestNewTimingWheelService_InitFail_NoPanicAndReturnError(t *testing.T) { + original := newTimingWheel + t.Cleanup(func() { newTimingWheel = original }) + + newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) { + return nil, errors.New("boom") + } + + svc, err := NewTimingWheelService() + if err == nil { + t.Fatalf("期望返回 error,但得到 nil") + } + if svc != nil { + t.Fatalf("期望返回 nil svc,但得到非空") + } +} + +func TestNewTimingWheelService_Success(t *testing.T) { + svc, err := NewTimingWheelService() + if err != nil { + t.Fatalf("期望 err 为 nil,但得到: %v", err) + } + if svc == nil { + t.Fatalf("期望 svc 非空,但得到 nil") + } + svc.Stop() +} + +func TestNewTimingWheelService_ExecuteCallbackRunsFunc(t *testing.T) { + original := newTimingWheel + t.Cleanup(func() { newTimingWheel = original }) + + var captured collection.Execute + newTimingWheel = func(interval time.Duration, numSlots int, execute collection.Execute) (*collection.TimingWheel, error) { + captured = execute + return original(interval, numSlots, execute) + } + + svc, err := NewTimingWheelService() + if err != nil { + t.Fatalf("期望 err 为 nil,但得到: %v", err) + } + if captured == nil { + t.Fatalf("期望 captured 非空,但得到 nil") + } + + called := false + captured("k", func() { called = true }) + if !called { + t.Fatalf("期望 execute 回调触发传入函数执行") + } + + svc.Stop() +} + +func TestTimingWheelService_Schedule_ExecutesOnce(t *testing.T) { + original := newTimingWheel + t.Cleanup(func() { newTimingWheel = original }) + + newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) { + return original(10*time.Millisecond, 128, execute) + } + + svc, err := NewTimingWheelService() + if err != nil { + t.Fatalf("期望 err 为 nil,但得到: %v", err) + } + defer svc.Stop() + + ch := make(chan struct{}, 1) + svc.Schedule("once", 30*time.Millisecond, func() { ch <- struct{}{} }) + + select { + case <-ch: + case <-time.After(500 * time.Millisecond): + t.Fatalf("等待任务执行超时") + } + + select { + case <-ch: + t.Fatalf("任务不应重复执行") + case <-time.After(80 * time.Millisecond): + } +} + +func TestTimingWheelService_Cancel_PreventsExecution(t *testing.T) { + original := newTimingWheel + t.Cleanup(func() { newTimingWheel = original }) + + newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) { + return original(10*time.Millisecond, 128, execute) + } + + svc, err := NewTimingWheelService() + if err != nil { + t.Fatalf("期望 err 为 nil,但得到: %v", err) + } + defer svc.Stop() + + ch := make(chan struct{}, 1) + svc.Schedule("cancel", 80*time.Millisecond, func() { ch <- struct{}{} }) + svc.Cancel("cancel") + + select { + case <-ch: + t.Fatalf("任务已取消,不应执行") + case <-time.After(200 * time.Millisecond): + } +} + +func TestTimingWheelService_ScheduleRecurring_ExecutesMultipleTimes(t *testing.T) { + original := newTimingWheel + t.Cleanup(func() { newTimingWheel = original }) + + newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) { + return original(10*time.Millisecond, 128, execute) + } + + svc, err := NewTimingWheelService() + if err != nil { + t.Fatalf("期望 err 为 nil,但得到: %v", err) + } + defer svc.Stop() + + var count int32 + svc.ScheduleRecurring("rec", 30*time.Millisecond, func() { atomic.AddInt32(&count, 1) }) + + deadline := time.Now().Add(500 * time.Millisecond) + for atomic.LoadInt32(&count) < 2 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } + if atomic.LoadInt32(&count) < 2 { + t.Fatalf("期望周期任务至少执行 2 次,但只执行了 %d 次", atomic.LoadInt32(&count)) + } +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 5ba093a4..acc0a5fb 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -65,10 +65,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe } // ProvideTimingWheelService creates and starts TimingWheelService -func ProvideTimingWheelService() *TimingWheelService { - svc := NewTimingWheelService() +func ProvideTimingWheelService() (*TimingWheelService, error) { + svc, err := NewTimingWheelService() + if err != nil { + return nil, err + } svc.Start() - return svc + return svc, nil } // ProvideDeferredService creates and starts DeferredService diff --git a/backend/internal/service/wire_test.go b/backend/internal/service/wire_test.go new file mode 100644 index 00000000..5f7866f6 --- /dev/null +++ b/backend/internal/service/wire_test.go @@ -0,0 +1,37 @@ +package service + +import ( + "errors" + "testing" + "time" + + "github.com/zeromicro/go-zero/core/collection" +) + +func TestProvideTimingWheelService_ReturnsError(t *testing.T) { + original := newTimingWheel + t.Cleanup(func() { newTimingWheel = original }) + + newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) { + return nil, errors.New("boom") + } + + svc, err := ProvideTimingWheelService() + if err == nil { + t.Fatalf("期望返回 error,但得到 nil") + } + if svc != nil { + t.Fatalf("期望返回 nil svc,但得到非空") + } +} + +func TestProvideTimingWheelService_Success(t *testing.T) { + svc, err := ProvideTimingWheelService() + if err != nil { + t.Fatalf("期望 err 为 nil,但得到: %v", err) + } + if svc == nil { + t.Fatalf("期望 svc 非空,但得到 nil") + } + svc.Stop() +}