84 lines
2.2 KiB
Go
84 lines
2.2 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
|
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
)
|
|
|
|
type announcementReadRepository struct {
|
|
client *dbent.Client
|
|
}
|
|
|
|
func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
|
|
return &announcementReadRepository{client: client}
|
|
}
|
|
|
|
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
|
|
client := clientFromContext(ctx, r.client)
|
|
return client.AnnouncementRead.Create().
|
|
SetAnnouncementID(announcementID).
|
|
SetUserID(userID).
|
|
SetReadAt(readAt).
|
|
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
|
|
DoNothing().
|
|
Exec(ctx)
|
|
}
|
|
|
|
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
|
|
if len(announcementIDs) == 0 {
|
|
return map[int64]time.Time{}, nil
|
|
}
|
|
|
|
rows, err := r.client.AnnouncementRead.Query().
|
|
Where(
|
|
announcementread.UserIDEQ(userID),
|
|
announcementread.AnnouncementIDIn(announcementIDs...),
|
|
).
|
|
All(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
out := make(map[int64]time.Time, len(rows))
|
|
for i := range rows {
|
|
out[rows[i].AnnouncementID] = rows[i].ReadAt
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
|
|
if len(userIDs) == 0 {
|
|
return map[int64]time.Time{}, nil
|
|
}
|
|
|
|
rows, err := r.client.AnnouncementRead.Query().
|
|
Where(
|
|
announcementread.AnnouncementIDEQ(announcementID),
|
|
announcementread.UserIDIn(userIDs...),
|
|
).
|
|
All(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
out := make(map[int64]time.Time, len(rows))
|
|
for i := range rows {
|
|
out[rows[i].UserID] = rows[i].ReadAt
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
|
|
count, err := r.client.AnnouncementRead.Query().
|
|
Where(announcementread.AnnouncementIDEQ(announcementID)).
|
|
Count(ctx)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int64(count), nil
|
|
}
|