Return bad request for invalid announcements

This commit is contained in:
IanShaw027
2026-04-21 09:52:20 -07:00
parent 0d87f94cb7
commit 89d09838d8
3 changed files with 111 additions and 14 deletions

View File

@@ -5,6 +5,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -34,8 +35,23 @@ const (
)
var (
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
ErrAnnouncementNilInput = infraerrors.BadRequest("ANNOUNCEMENT_INPUT_REQUIRED", "announcement input is required")
ErrAnnouncementInvalidTitle = infraerrors.BadRequest("ANNOUNCEMENT_TITLE_INVALID", "announcement title is invalid")
ErrAnnouncementContentRequired = infraerrors.BadRequest(
"ANNOUNCEMENT_CONTENT_REQUIRED",
"announcement content is required",
)
ErrAnnouncementInvalidStatus = infraerrors.BadRequest("ANNOUNCEMENT_STATUS_INVALID", "announcement status is invalid")
ErrAnnouncementInvalidNotifyMode = infraerrors.BadRequest(
"ANNOUNCEMENT_NOTIFY_MODE_INVALID",
"announcement notify_mode is invalid",
)
ErrAnnouncementInvalidSchedule = infraerrors.BadRequest(
"ANNOUNCEMENT_TIME_RANGE_INVALID",
"starts_at must be before ends_at",
)
)
type AnnouncementTargeting = domain.AnnouncementTargeting

View File

@@ -70,16 +70,16 @@ type AnnouncementUserReadStatus struct {
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("create announcement: nil input")
return nil, ErrAnnouncementNilInput
}
title := strings.TrimSpace(input.Title)
content := strings.TrimSpace(input.Content)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("create announcement: invalid title")
return nil, ErrAnnouncementInvalidTitle
}
if content == "" {
return nil, fmt.Errorf("create announcement: content is required")
return nil, ErrAnnouncementContentRequired
}
status := strings.TrimSpace(input.Status)
@@ -87,7 +87,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
status = AnnouncementStatusDraft
}
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("create announcement: invalid status")
return nil, ErrAnnouncementInvalidStatus
}
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
@@ -100,12 +100,12 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
notifyMode = AnnouncementNotifyModeSilent
}
if !isValidAnnouncementNotifyMode(notifyMode) {
return nil, fmt.Errorf("create announcement: invalid notify_mode")
return nil, ErrAnnouncementInvalidNotifyMode
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
return nil, ErrAnnouncementInvalidSchedule
}
}
@@ -131,7 +131,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("update announcement: nil input")
return nil, ErrAnnouncementNilInput
}
a, err := s.announcementRepo.GetByID(ctx, id)
@@ -142,21 +142,21 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.Title != nil {
title := strings.TrimSpace(*input.Title)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("update announcement: invalid title")
return nil, ErrAnnouncementInvalidTitle
}
a.Title = title
}
if input.Content != nil {
content := strings.TrimSpace(*input.Content)
if content == "" {
return nil, fmt.Errorf("update announcement: content is required")
return nil, ErrAnnouncementContentRequired
}
a.Content = content
}
if input.Status != nil {
status := strings.TrimSpace(*input.Status)
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("update announcement: invalid status")
return nil, ErrAnnouncementInvalidStatus
}
a.Status = status
}
@@ -164,7 +164,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.NotifyMode != nil {
notifyMode := strings.TrimSpace(*input.NotifyMode)
if !isValidAnnouncementNotifyMode(notifyMode) {
return nil, fmt.Errorf("update announcement: invalid notify_mode")
return nil, ErrAnnouncementInvalidNotifyMode
}
a.NotifyMode = notifyMode
}
@@ -186,7 +186,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if a.StartsAt != nil && a.EndsAt != nil {
if !a.StartsAt.Before(*a.EndsAt) {
return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
return nil, ErrAnnouncementInvalidSchedule
}
}

View File

@@ -0,0 +1,81 @@
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type announcementRepoStub struct {
item *Announcement
}
func (s *announcementRepoStub) Create(_ context.Context, a *Announcement) error {
s.item = a
return nil
}
func (s *announcementRepoStub) GetByID(_ context.Context, _ int64) (*Announcement, error) {
if s.item == nil {
return nil, ErrAnnouncementNotFound
}
return s.item, nil
}
func (s *announcementRepoStub) Update(_ context.Context, a *Announcement) error {
s.item = a
return nil
}
func (*announcementRepoStub) Delete(context.Context, int64) error {
return nil
}
func (*announcementRepoStub) List(context.Context, pagination.PaginationParams, AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (*announcementRepoStub) ListActive(context.Context, time.Time) ([]Announcement, error) {
return nil, nil
}
func TestAnnouncementServiceCreateRejectsEqualStartEndTimes(t *testing.T) {
repo := &announcementRepoStub{}
svc := NewAnnouncementService(repo, nil, nil, nil)
now := time.Unix(1776790020, 0)
_, err := svc.Create(context.Background(), &CreateAnnouncementInput{
Title: "公告",
Content: "内容",
Status: AnnouncementStatusActive,
NotifyMode: AnnouncementNotifyModePopup,
StartsAt: &now,
EndsAt: &now,
})
require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
}
func TestAnnouncementServiceUpdateRejectsEqualStartEndTimes(t *testing.T) {
repo := &announcementRepoStub{
item: &Announcement{
ID: 1,
Title: "公告",
Content: "内容",
Status: AnnouncementStatusActive,
NotifyMode: AnnouncementNotifyModePopup,
},
}
svc := NewAnnouncementService(repo, nil, nil, nil)
now := time.Unix(1776790020, 0)
startsAt := &now
endsAt := &now
_, err := svc.Update(context.Background(), 1, &UpdateAnnouncementInput{
StartsAt: &startsAt,
EndsAt: &endsAt,
})
require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
}