refactor: 自定义业务错误

This commit is contained in:
Forest
2025-12-25 20:52:47 +08:00
parent f51ad2e126
commit eeaff85e47
60 changed files with 1222 additions and 622 deletions

View File

@@ -2,56 +2,61 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type UserRepository struct {
type userRepository struct {
db *gorm.DB
}
func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
func NewUserRepository(db *gorm.DB) service.UserRepository {
return &userRepository{db: db}
}
func (r *UserRepository) Create(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Create(user).Error
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Create(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
}
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
}
func (r *UserRepository) Update(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Save(user).Error
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Save(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
}
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User
var total int64
@@ -120,13 +125,13 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
}, nil
}
func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error
}
// DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}).
Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount))
@@ -134,17 +139,17 @@ func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount flo
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 余额不足或用户不存在
return service.ErrInsufficientBalance
}
return nil
}
func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
}
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
return count > 0, err
@@ -152,7 +157,7 @@ func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool,
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数
func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}).
Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
@@ -160,14 +165,14 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
}
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Order("id ASC").
First(&user).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
}