feat(announcements): add admin/user announcement system

Implements announcements end-to-end (admin CRUD + read status, user list + mark read) with OR-of-AND targeting. Also breaks the ent<->service import cycle by moving schema-facing constants/targeting into a new domain package.
This commit is contained in:
ducky
2026-01-30 16:45:04 +08:00
parent cadca752c4
commit b7f69844e1
70 changed files with 12366 additions and 71 deletions

View File

@@ -0,0 +1,64 @@
package service
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
AnnouncementStatusDraft = domain.AnnouncementStatusDraft
AnnouncementStatusActive = domain.AnnouncementStatusActive
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
)
const (
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
)
const (
AnnouncementOperatorIn = domain.AnnouncementOperatorIn
AnnouncementOperatorGT = domain.AnnouncementOperatorGT
AnnouncementOperatorGTE = domain.AnnouncementOperatorGTE
AnnouncementOperatorLT = domain.AnnouncementOperatorLT
AnnouncementOperatorLTE = domain.AnnouncementOperatorLTE
AnnouncementOperatorEQ = domain.AnnouncementOperatorEQ
)
var (
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
)
type AnnouncementTargeting = domain.AnnouncementTargeting
type AnnouncementConditionGroup = domain.AnnouncementConditionGroup
type AnnouncementCondition = domain.AnnouncementCondition
type Announcement = domain.Announcement
type AnnouncementListFilters struct {
Status string
Search string
}
type AnnouncementRepository interface {
Create(ctx context.Context, a *Announcement) error
GetByID(ctx context.Context, id int64) (*Announcement, error)
Update(ctx context.Context, a *Announcement) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error)
ListActive(ctx context.Context, now time.Time) ([]Announcement, error)
}
type AnnouncementReadRepository interface {
MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error
GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error)
GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error)
CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error)
}

View File

@@ -0,0 +1,378 @@
package service
import (
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type AnnouncementService struct {
announcementRepo AnnouncementRepository
readRepo AnnouncementReadRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
}
func NewAnnouncementService(
announcementRepo AnnouncementRepository,
readRepo AnnouncementReadRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
) *AnnouncementService {
return &AnnouncementService{
announcementRepo: announcementRepo,
readRepo: readRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
}
}
type CreateAnnouncementInput struct {
Title string
Content string
Status string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
ActorID *int64 // 管理员用户ID
}
type UpdateAnnouncementInput struct {
Title *string
Content *string
Status *string
Targeting *AnnouncementTargeting
StartsAt **time.Time
EndsAt **time.Time
ActorID *int64 // 管理员用户ID
}
type UserAnnouncement struct {
Announcement Announcement
ReadAt *time.Time
}
type AnnouncementUserReadStatus struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Balance float64 `json:"balance"`
Eligible bool `json:"eligible"`
ReadAt *time.Time `json:"read_at,omitempty"`
}
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("create announcement: nil input")
}
title := strings.TrimSpace(input.Title)
content := strings.TrimSpace(input.Content)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("create announcement: invalid title")
}
if content == "" {
return nil, fmt.Errorf("create announcement: content is required")
}
status := strings.TrimSpace(input.Status)
if status == "" {
status = AnnouncementStatusDraft
}
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("create announcement: invalid status")
}
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
if err != nil {
return nil, err
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
}
}
a := &Announcement{
Title: title,
Content: content,
Status: status,
Targeting: targeting,
StartsAt: input.StartsAt,
EndsAt: input.EndsAt,
}
if input.ActorID != nil && *input.ActorID > 0 {
a.CreatedBy = input.ActorID
a.UpdatedBy = input.ActorID
}
if err := s.announcementRepo.Create(ctx, a); err != nil {
return nil, fmt.Errorf("create announcement: %w", err)
}
return a, nil
}
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("update announcement: nil input")
}
a, err := s.announcementRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Title != nil {
title := strings.TrimSpace(*input.Title)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("update announcement: invalid title")
}
a.Title = title
}
if input.Content != nil {
content := strings.TrimSpace(*input.Content)
if content == "" {
return nil, fmt.Errorf("update announcement: content is required")
}
a.Content = content
}
if input.Status != nil {
status := strings.TrimSpace(*input.Status)
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("update announcement: invalid status")
}
a.Status = status
}
if input.Targeting != nil {
targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
if err != nil {
return nil, err
}
a.Targeting = targeting
}
if input.StartsAt != nil {
a.StartsAt = *input.StartsAt
}
if input.EndsAt != nil {
a.EndsAt = *input.EndsAt
}
if a.StartsAt != nil && a.EndsAt != nil {
if !a.StartsAt.Before(*a.EndsAt) {
return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
}
}
if input.ActorID != nil && *input.ActorID > 0 {
a.UpdatedBy = input.ActorID
}
if err := s.announcementRepo.Update(ctx, a); err != nil {
return nil, fmt.Errorf("update announcement: %w", err)
}
return a, nil
}
func (s *AnnouncementService) Delete(ctx context.Context, id int64) error {
if err := s.announcementRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete announcement: %w", err)
}
return nil
}
func (s *AnnouncementService) GetByID(ctx context.Context, id int64) (*Announcement, error) {
return s.announcementRepo.GetByID(ctx, id)
}
func (s *AnnouncementService) List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
return s.announcementRepo.List(ctx, params, filters)
}
func (s *AnnouncementService) ListForUser(ctx context.Context, userID int64, unreadOnly bool) ([]UserAnnouncement, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
for i := range activeSubs {
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
}
now := time.Now()
anns, err := s.announcementRepo.ListActive(ctx, now)
if err != nil {
return nil, fmt.Errorf("list active announcements: %w", err)
}
visible := make([]Announcement, 0, len(anns))
ids := make([]int64, 0, len(anns))
for i := range anns {
a := anns[i]
if !a.IsActiveAt(now) {
continue
}
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
continue
}
visible = append(visible, a)
ids = append(ids, a.ID)
}
if len(visible) == 0 {
return []UserAnnouncement{}, nil
}
readMap, err := s.readRepo.GetReadMapByUser(ctx, userID, ids)
if err != nil {
return nil, fmt.Errorf("get read map: %w", err)
}
out := make([]UserAnnouncement, 0, len(visible))
for i := range visible {
a := visible[i]
readAt, ok := readMap[a.ID]
if unreadOnly && ok {
continue
}
var ptr *time.Time
if ok {
t := readAt
ptr = &t
}
out = append(out, UserAnnouncement{
Announcement: a,
ReadAt: ptr,
})
}
// 未读优先、同状态按创建时间倒序
sort.Slice(out, func(i, j int) bool {
ai, aj := out[i], out[j]
if (ai.ReadAt == nil) != (aj.ReadAt == nil) {
return ai.ReadAt == nil
}
return ai.Announcement.ID > aj.Announcement.ID
})
return out, nil
}
func (s *AnnouncementService) MarkRead(ctx context.Context, userID, announcementID int64) error {
// 安全:仅允许标记当前用户“可见”的公告
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
a, err := s.announcementRepo.GetByID(ctx, announcementID)
if err != nil {
return err
}
now := time.Now()
if !a.IsActiveAt(now) {
return ErrAnnouncementNotFound
}
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
for i := range activeSubs {
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
}
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
return ErrAnnouncementNotFound
}
if err := s.readRepo.MarkRead(ctx, announcementID, userID, now); err != nil {
return fmt.Errorf("mark read: %w", err)
}
return nil
}
func (s *AnnouncementService) ListUserReadStatus(
ctx context.Context,
announcementID int64,
params pagination.PaginationParams,
search string,
) ([]AnnouncementUserReadStatus, *pagination.PaginationResult, error) {
ann, err := s.announcementRepo.GetByID(ctx, announcementID)
if err != nil {
return nil, nil, err
}
filters := UserListFilters{
Search: strings.TrimSpace(search),
}
users, page, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)
}
userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
readMap, err := s.readRepo.GetReadMapByUsers(ctx, announcementID, userIDs)
if err != nil {
return nil, nil, fmt.Errorf("get read map: %w", err)
}
out := make([]AnnouncementUserReadStatus, 0, len(users))
for i := range users {
u := users[i]
subs, err := s.userSubRepo.ListActiveByUserID(ctx, u.ID)
if err != nil {
return nil, nil, fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(subs))
for j := range subs {
activeGroupIDs[subs[j].GroupID] = struct{}{}
}
readAt, ok := readMap[u.ID]
var ptr *time.Time
if ok {
t := readAt
ptr = &t
}
out = append(out, AnnouncementUserReadStatus{
UserID: u.ID,
Email: u.Email,
Username: u.Username,
Balance: u.Balance,
Eligible: domain.AnnouncementTargeting(ann.Targeting).Matches(u.Balance, activeGroupIDs),
ReadAt: ptr,
})
}
return out, page, nil
}
func isValidAnnouncementStatus(status string) bool {
switch status {
case AnnouncementStatusDraft, AnnouncementStatusActive, AnnouncementStatusArchived:
return true
default:
return false
}
}

View File

@@ -0,0 +1,67 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAnnouncementTargeting_Matches_EmptyMatchesAll(t *testing.T) {
var targeting AnnouncementTargeting
require.True(t, targeting.Matches(0, nil))
require.True(t, targeting.Matches(123.45, map[int64]struct{}{1: {}}))
}
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{AllOf: nil},
},
}
_, err := targeting.NormalizeAndValidate()
require.Error(t, err)
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
}
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{
AllOf: []AnnouncementCondition{
{Type: "balance", Operator: "between", Value: 10},
},
},
},
}
_, err := targeting.NormalizeAndValidate()
require.Error(t, err)
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
}
func TestAnnouncementTargeting_Matches_AndOrSemantics(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{
AllOf: []AnnouncementCondition{
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorGTE, Value: 100},
{Type: AnnouncementConditionTypeSubscription, Operator: AnnouncementOperatorIn, GroupIDs: []int64{10}},
},
},
{
AllOf: []AnnouncementCondition{
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorLT, Value: 5},
},
},
},
}
// 命中第 2 组balance < 5
require.True(t, targeting.Matches(4.99, nil))
require.False(t, targeting.Matches(5, nil))
// 命中第 1 组balance >= 100 AND 订阅 in [10]
require.False(t, targeting.Matches(100, map[int64]struct{}{}))
require.False(t, targeting.Matches(99.9, map[int64]struct{}{10: {}}))
require.True(t, targeting.Matches(100, map[int64]struct{}{10: {}}))
}

View File

@@ -1,66 +1,68 @@
package service
import "github.com/Wei-Shaw/sub2api/internal/domain"
// Status constants
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
StatusActive = domain.StatusActive
StatusDisabled = domain.StatusDisabled
StatusError = domain.StatusError
StatusUnused = domain.StatusUnused
StatusUsed = domain.StatusUsed
StatusExpired = domain.StatusExpired
)
// Role constants
const (
RoleAdmin = "admin"
RoleUser = "user"
RoleAdmin = domain.RoleAdmin
RoleUser = domain.RoleUser
)
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
PlatformAnthropic = domain.PlatformAnthropic
PlatformOpenAI = domain.PlatformOpenAI
PlatformGemini = domain.PlatformGemini
PlatformAntigravity = domain.PlatformAntigravity
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号inference only scope
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
)
// Redeem type constants
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
RedeemTypeBalance = domain.RedeemTypeBalance
RedeemTypeConcurrency = domain.RedeemTypeConcurrency
RedeemTypeSubscription = domain.RedeemTypeSubscription
)
// PromoCode status constants
const (
PromoCodeStatusActive = "active"
PromoCodeStatusDisabled = "disabled"
PromoCodeStatusActive = domain.PromoCodeStatusActive
PromoCodeStatusDisabled = domain.PromoCodeStatusDisabled
)
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
AdjustmentTypeAdminBalance = domain.AdjustmentTypeAdminBalance // 管理员调整余额
AdjustmentTypeAdminConcurrency = domain.AdjustmentTypeAdminConcurrency // 管理员调整并发数
)
// Group subscription type constants
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
SubscriptionTypeStandard = domain.SubscriptionTypeStandard // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = domain.SubscriptionTypeSubscription // 订阅模式(按限额控制)
)
// Subscription status constants
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
SubscriptionStatusActive = domain.SubscriptionStatusActive
SubscriptionStatusExpired = domain.SubscriptionStatusExpired
SubscriptionStatusSuspended = domain.SubscriptionStatusSuspended
)
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀RFC 保留域名)。

View File

@@ -226,6 +226,7 @@ var ProviderSet = wire.NewSet(
ProvidePricingService,
NewBillingService,
NewBillingCacheService,
NewAnnouncementService,
NewAdminService,
NewGatewayService,
NewOpenAIGatewayService,