258 lines
8.0 KiB
Go
258 lines
8.0 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
|
|
|
"github.com/lib/pq"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type userRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewUserRepository(db *gorm.DB) service.UserRepository {
|
|
return &userRepository{db: db}
|
|
}
|
|
|
|
func (r *userRepository) Create(ctx context.Context, user *service.User) error {
|
|
m := userModelFromService(user)
|
|
err := r.db.WithContext(ctx).Create(m).Error
|
|
if err == nil {
|
|
applyUserModelToService(user, m)
|
|
}
|
|
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
|
}
|
|
|
|
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
|
var m userModel
|
|
err := r.db.WithContext(ctx).First(&m, id).Error
|
|
if err != nil {
|
|
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
|
}
|
|
return userModelToService(&m), nil
|
|
}
|
|
|
|
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
|
var m userModel
|
|
err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error
|
|
if err != nil {
|
|
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
|
}
|
|
return userModelToService(&m), nil
|
|
}
|
|
|
|
func (r *userRepository) Update(ctx context.Context, user *service.User) error {
|
|
m := userModelFromService(user)
|
|
err := r.db.WithContext(ctx).Save(m).Error
|
|
if err == nil {
|
|
applyUserModelToService(user, m)
|
|
}
|
|
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
|
}
|
|
|
|
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
|
return r.db.WithContext(ctx).Delete(&userModel{}, id).Error
|
|
}
|
|
|
|
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
|
return r.ListWithFilters(ctx, params, "", "", "")
|
|
}
|
|
|
|
// ListWithFilters lists users with optional filtering by status, role, and search query
|
|
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
|
|
var users []userModel
|
|
var total int64
|
|
|
|
db := r.db.WithContext(ctx).Model(&userModel{})
|
|
|
|
// Apply filters
|
|
if status != "" {
|
|
db = db.Where("status = ?", status)
|
|
}
|
|
if role != "" {
|
|
db = db.Where("role = ?", role)
|
|
}
|
|
if search != "" {
|
|
searchPattern := "%" + search + "%"
|
|
db = db.Where(
|
|
"email ILIKE ? OR username ILIKE ? OR wechat ILIKE ?",
|
|
searchPattern, searchPattern, searchPattern,
|
|
)
|
|
}
|
|
|
|
if err := db.Count(&total).Error; err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Query users with pagination (reuse the same db with filters applied)
|
|
if err := db.Offset(params.Offset()).Limit(params.Limit()).Order("id DESC").Find(&users).Error; err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Batch load subscriptions for all users (avoid N+1)
|
|
if len(users) > 0 {
|
|
userIDs := make([]int64, len(users))
|
|
userMap := make(map[int64]*service.User, len(users))
|
|
outUsers := make([]service.User, 0, len(users))
|
|
for i := range users {
|
|
userIDs[i] = users[i].ID
|
|
u := userModelToService(&users[i])
|
|
outUsers = append(outUsers, *u)
|
|
userMap[u.ID] = &outUsers[len(outUsers)-1]
|
|
}
|
|
|
|
// Query active subscriptions with groups in one query
|
|
var subscriptions []userSubscriptionModel
|
|
if err := r.db.WithContext(ctx).
|
|
Preload("Group").
|
|
Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive).
|
|
Find(&subscriptions).Error; err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Associate subscriptions with users
|
|
for i := range subscriptions {
|
|
if user, ok := userMap[subscriptions[i].UserID]; ok {
|
|
user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i]))
|
|
}
|
|
}
|
|
|
|
return outUsers, paginationResultFromTotal(total, params), nil
|
|
}
|
|
|
|
outUsers := make([]service.User, 0, len(users))
|
|
for i := range users {
|
|
outUsers = append(outUsers, *userModelToService(&users[i]))
|
|
}
|
|
|
|
return outUsers, paginationResultFromTotal(total, params), nil
|
|
}
|
|
|
|
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
|
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
|
|
Update("balance", gorm.Expr("balance + ?", amount)).Error
|
|
}
|
|
|
|
// DeductBalance 扣减用户余额,仅当余额充足时执行
|
|
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
|
result := r.db.WithContext(ctx).Model(&userModel{}).
|
|
Where("id = ? AND balance >= ?", id, amount).
|
|
Update("balance", gorm.Expr("balance - ?", amount))
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
if result.RowsAffected == 0 {
|
|
return service.ErrInsufficientBalance
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
|
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
|
|
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
|
|
}
|
|
|
|
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
|
var count int64
|
|
err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error
|
|
return count > 0, err
|
|
}
|
|
|
|
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
|
|
// 使用 PostgreSQL 的 array_remove 函数
|
|
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
|
result := r.db.WithContext(ctx).Model(&userModel{}).
|
|
Where("? = ANY(allowed_groups)", groupID).
|
|
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
|
|
return result.RowsAffected, result.Error
|
|
}
|
|
|
|
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
|
|
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
|
var m userModel
|
|
err := r.db.WithContext(ctx).
|
|
Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive).
|
|
Order("id ASC").
|
|
First(&m).Error
|
|
if err != nil {
|
|
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
|
}
|
|
return userModelToService(&m), nil
|
|
}
|
|
|
|
type userModel struct {
|
|
ID int64 `gorm:"primaryKey"`
|
|
Email string `gorm:"uniqueIndex;size:255;not null"`
|
|
Username string `gorm:"size:100;default:''"`
|
|
Wechat string `gorm:"size:100;default:''"`
|
|
Notes string `gorm:"type:text;default:''"`
|
|
PasswordHash string `gorm:"size:255;not null"`
|
|
Role string `gorm:"size:20;default:user;not null"`
|
|
Balance float64 `gorm:"type:decimal(20,8);default:0;not null"`
|
|
Concurrency int `gorm:"default:5;not null"`
|
|
Status string `gorm:"size:20;default:active;not null"`
|
|
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
|
|
CreatedAt time.Time `gorm:"not null"`
|
|
UpdatedAt time.Time `gorm:"not null"`
|
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
|
}
|
|
|
|
// TableName returns "users", the name of the table userModel rows belong to.
func (userModel) TableName() string {
	const table = "users"
	return table
}
|
|
|
|
func userModelToService(m *userModel) *service.User {
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
return &service.User{
|
|
ID: m.ID,
|
|
Email: m.Email,
|
|
Username: m.Username,
|
|
Wechat: m.Wechat,
|
|
Notes: m.Notes,
|
|
PasswordHash: m.PasswordHash,
|
|
Role: m.Role,
|
|
Balance: m.Balance,
|
|
Concurrency: m.Concurrency,
|
|
Status: m.Status,
|
|
AllowedGroups: []int64(m.AllowedGroups),
|
|
CreatedAt: m.CreatedAt,
|
|
UpdatedAt: m.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func userModelFromService(u *service.User) *userModel {
|
|
if u == nil {
|
|
return nil
|
|
}
|
|
return &userModel{
|
|
ID: u.ID,
|
|
Email: u.Email,
|
|
Username: u.Username,
|
|
Wechat: u.Wechat,
|
|
Notes: u.Notes,
|
|
PasswordHash: u.PasswordHash,
|
|
Role: u.Role,
|
|
Balance: u.Balance,
|
|
Concurrency: u.Concurrency,
|
|
Status: u.Status,
|
|
AllowedGroups: pq.Int64Array(u.AllowedGroups),
|
|
CreatedAt: u.CreatedAt,
|
|
UpdatedAt: u.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func applyUserModelToService(dst *service.User, src *userModel) {
|
|
if dst == nil || src == nil {
|
|
return
|
|
}
|
|
dst.ID = src.ID
|
|
dst.CreatedAt = src.CreatedAt
|
|
dst.UpdatedAt = src.UpdatedAt
|
|
}
|