diff --git a/backend/internal/service/announcement.go b/backend/internal/service/announcement.go index 25c66eb4..02741d37 100644 --- a/backend/internal/service/announcement.go +++ b/backend/internal/service/announcement.go @@ -5,6 +5,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/domain" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -34,8 +35,23 @@ const ( ) var ( - ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound - ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget + ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound + ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget + ErrAnnouncementNilInput = infraerrors.BadRequest("ANNOUNCEMENT_INPUT_REQUIRED", "announcement input is required") + ErrAnnouncementInvalidTitle = infraerrors.BadRequest("ANNOUNCEMENT_TITLE_INVALID", "announcement title is invalid") + ErrAnnouncementContentRequired = infraerrors.BadRequest( + "ANNOUNCEMENT_CONTENT_REQUIRED", + "announcement content is required", + ) + ErrAnnouncementInvalidStatus = infraerrors.BadRequest("ANNOUNCEMENT_STATUS_INVALID", "announcement status is invalid") + ErrAnnouncementInvalidNotifyMode = infraerrors.BadRequest( + "ANNOUNCEMENT_NOTIFY_MODE_INVALID", + "announcement notify_mode is invalid", + ) + ErrAnnouncementInvalidSchedule = infraerrors.BadRequest( + "ANNOUNCEMENT_TIME_RANGE_INVALID", + "starts_at must be before ends_at", + ) ) type AnnouncementTargeting = domain.AnnouncementTargeting diff --git a/backend/internal/service/announcement_service.go b/backend/internal/service/announcement_service.go index c0a0681a..12479041 100644 --- a/backend/internal/service/announcement_service.go +++ b/backend/internal/service/announcement_service.go @@ -70,16 +70,16 @@ type AnnouncementUserReadStatus struct { func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) { if input == nil { - return nil, fmt.Errorf("create announcement: nil input") + return nil, ErrAnnouncementNilInput } title := strings.TrimSpace(input.Title) content := strings.TrimSpace(input.Content) if title == "" || len(title) > 200 { - return nil, fmt.Errorf("create announcement: invalid title") + return nil, ErrAnnouncementInvalidTitle } if content == "" { - return nil, fmt.Errorf("create announcement: content is required") + return nil, ErrAnnouncementContentRequired } status := strings.TrimSpace(input.Status) @@ -87,7 +87,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem status = AnnouncementStatusDraft } if !isValidAnnouncementStatus(status) { - return nil, fmt.Errorf("create announcement: invalid status") + return nil, ErrAnnouncementInvalidStatus } targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate() @@ -100,12 +100,12 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem notifyMode = AnnouncementNotifyModeSilent } if !isValidAnnouncementNotifyMode(notifyMode) { - return nil, fmt.Errorf("create announcement: invalid notify_mode") + return nil, ErrAnnouncementInvalidNotifyMode } if input.StartsAt != nil && input.EndsAt != nil { if !input.StartsAt.Before(*input.EndsAt) { - return nil, fmt.Errorf("create announcement: starts_at must be before ends_at") + return nil, ErrAnnouncementInvalidSchedule } } @@ -131,7 +131,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) { if input == nil { - return nil, fmt.Errorf("update announcement: nil input") + return nil, ErrAnnouncementNilInput } a, err := s.announcementRepo.GetByID(ctx, id) @@ -142,21 +142,21 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat if input.Title != nil { title := strings.TrimSpace(*input.Title) if title == "" || len(title) > 200 { - return nil, fmt.Errorf("update announcement: invalid title") + return nil, ErrAnnouncementInvalidTitle } a.Title = title } if input.Content != nil { content := strings.TrimSpace(*input.Content) if content == "" { - return nil, fmt.Errorf("update announcement: content is required") + return nil, ErrAnnouncementContentRequired } a.Content = content } if input.Status != nil { status := strings.TrimSpace(*input.Status) if !isValidAnnouncementStatus(status) { - return nil, fmt.Errorf("update announcement: invalid status") + return nil, ErrAnnouncementInvalidStatus } a.Status = status } @@ -164,7 +164,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat if input.NotifyMode != nil { notifyMode := strings.TrimSpace(*input.NotifyMode) if !isValidAnnouncementNotifyMode(notifyMode) { - return nil, fmt.Errorf("update announcement: invalid notify_mode") + return nil, ErrAnnouncementInvalidNotifyMode } a.NotifyMode = notifyMode } @@ -186,7 +186,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat if a.StartsAt != nil && a.EndsAt != nil { if !a.StartsAt.Before(*a.EndsAt) { - return nil, fmt.Errorf("update announcement: starts_at must be before ends_at") + return nil, ErrAnnouncementInvalidSchedule } } diff --git a/backend/internal/service/announcement_service_test.go b/backend/internal/service/announcement_service_test.go new file mode 100644 index 00000000..77fb9896 --- /dev/null +++ b/backend/internal/service/announcement_service_test.go @@ -0,0 +1,81 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type announcementRepoStub struct { + item *Announcement +} + +func (s *announcementRepoStub) Create(_ context.Context, a *Announcement) error { + s.item = a + return nil +} + +func (s *announcementRepoStub) GetByID(_ context.Context, _ int64) (*Announcement, error) { + if s.item == nil { + return nil, ErrAnnouncementNotFound + } + return s.item, nil +} + +func (s *announcementRepoStub) Update(_ context.Context, a *Announcement) error { + s.item = a + return nil +} + +func (*announcementRepoStub) Delete(context.Context, int64) error { + return nil +} + +func (*announcementRepoStub) List(context.Context, pagination.PaginationParams, AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) { + return nil, nil, nil +} + +func (*announcementRepoStub) ListActive(context.Context, time.Time) ([]Announcement, error) { + return nil, nil +} + +func TestAnnouncementServiceCreateRejectsEqualStartEndTimes(t *testing.T) { + repo := &announcementRepoStub{} + svc := NewAnnouncementService(repo, nil, nil, nil) + now := time.Unix(1776790020, 0) + + _, err := svc.Create(context.Background(), &CreateAnnouncementInput{ + Title: "公告", + Content: "内容", + Status: AnnouncementStatusActive, + NotifyMode: AnnouncementNotifyModePopup, + StartsAt: &now, + EndsAt: &now, + }) + require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule) +} + +func TestAnnouncementServiceUpdateRejectsEqualStartEndTimes(t *testing.T) { + repo := &announcementRepoStub{ + item: &Announcement{ + ID: 1, + Title: "公告", + Content: "内容", + Status: AnnouncementStatusActive, + NotifyMode: AnnouncementNotifyModePopup, + }, + } + svc := NewAnnouncementService(repo, nil, nil, nil) + now := time.Unix(1776790020, 0) + startsAt := &now + endsAt := &now + + _, err := svc.Update(context.Background(), 1, &UpdateAnnouncementInput{ + StartsAt: &startsAt, + EndsAt: &endsAt, + }) + require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule) +}