Files
sub2api/backend/internal/service/user_service.go
2025-12-18 13:50:39 +08:00

178 lines
4.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// Sentinel errors returned by UserService. Compare with errors.Is.
// (ErrEmailExists, returned by UpdateProfile, is not declared in this group.)
var (
	// ErrUserNotFound is returned when the repository reports that no user
	// exists for the requested ID.
	ErrUserNotFound = errors.New("user not found")

	// ErrPasswordIncorrect is returned by ChangePassword when the supplied
	// current password is rejected.
	ErrPasswordIncorrect = errors.New("current password is incorrect")

	// ErrInsufficientPerms is not referenced by the methods in this section.
	// NOTE(review): presumably used by handlers/middleware — confirm.
	ErrInsufficientPerms = errors.New("insufficient permissions")
)
// UpdateProfileRequest is the payload for UserService.UpdateProfile.
// A nil field means "leave this value unchanged".
type UpdateProfileRequest struct {
	// Email, when non-nil, replaces the user's email after a uniqueness check.
	Email *string `json:"email"`
	// Concurrency, when non-nil, replaces the user's concurrency value.
	Concurrency *int `json:"concurrency"`
}
// ChangePasswordRequest is the payload for UserService.ChangePassword.
type ChangePasswordRequest struct {
	// CurrentPassword is checked against the stored bcrypt hash first.
	CurrentPassword string `json:"current_password"`
	// NewPassword is bcrypt-hashed (default cost) and stored on success.
	NewPassword string `json:"new_password"`
}
// UserService implements self-service profile operations and admin-side
// user management on top of the user repository.
type UserService struct {
	userRepo *repository.UserRepository // persistence for model.User records
	cfg      *config.Config             // NOTE(review): not read by the methods visible here — confirm it is needed
}
// NewUserService returns a UserService that uses userRepo for persistence
// and keeps a reference to cfg. Both pointers are stored as given; no nil
// checks are performed.
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
	svc := new(UserService)
	svc.userRepo = userRepo
	svc.cfg = cfg
	return svc
}
// GetProfile returns the user identified by userID (self-service view).
//
// It returns ErrUserNotFound when the repository reports
// gorm.ErrRecordNotFound; any other repository error is wrapped with
// "get user".
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
	u, err := s.userRepo.GetByID(ctx, userID)
	switch {
	case err == nil:
		return u, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrUserNotFound
	default:
		return nil, fmt.Errorf("get user: %w", err)
	}
}
// UpdateProfile applies the non-nil fields of req to the user identified by
// userID, persists the record, and returns the updated user.
//
// Errors:
//   - ErrUserNotFound if the repository reports gorm.ErrRecordNotFound.
//   - ErrEmailExists (declared outside this block) if req.Email differs from
//     the user's current email and is already registered.
//   - Other repository failures are wrapped with context.
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrUserNotFound
		}
		return nil, fmt.Errorf("get user: %w", err)
	}

	// Only run the uniqueness query when the email actually changes:
	// re-submitting the current email is a no-op and should not depend on
	// (or fail because of) an extra database round-trip.
	if req.Email != nil && *req.Email != user.Email {
		exists, existsErr := s.userRepo.ExistsByEmail(ctx, *req.Email)
		if existsErr != nil {
			return nil, fmt.Errorf("check email exists: %w", existsErr)
		}
		if exists {
			return nil, ErrEmailExists
		}
		// NOTE(review): check-then-write is racy under concurrent requests;
		// presumably a unique index on email backs this up — confirm.
		user.Email = *req.Email
	}

	if req.Concurrency != nil {
		// NOTE(review): stored without any range check here; confirm that
		// validation (and whether users may change this themselves) is
		// enforced by the caller.
		user.Concurrency = *req.Concurrency
	}

	if err = s.userRepo.Update(ctx, user); err != nil {
		return nil, fmt.Errorf("update user: %w", err)
	}
	return user, nil
}
// ChangePassword verifies req.CurrentPassword against the stored bcrypt hash
// of the user identified by userID. If it matches, it stores a bcrypt hash
// (default cost) of req.NewPassword.
//
// Errors:
//   - ErrUserNotFound if the repository reports gorm.ErrRecordNotFound.
//   - ErrPasswordIncorrect if the current password does not match. An empty or
//     truncated stored hash (bcrypt.ErrHashTooShort) also maps here.
//   - Any other verification failure means the stored hash is malformed
//     (bad prefix, version, or cost). It is returned as a wrapped internal
//     error instead of being disguised as a wrong password.
//   - Hashing and repository failures are wrapped with context. This includes
//     bcrypt's rejection of inputs over 72 bytes in recent x/crypto versions.
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrUserNotFound
		}
		return fmt.Errorf("get user: %w", err)
	}

	// Verify the current password; don't swallow unexpected bcrypt errors.
	if err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
		if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) || errors.Is(err, bcrypt.ErrHashTooShort) {
			return ErrPasswordIncorrect
		}
		return fmt.Errorf("verify current password: %w", err)
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("hash password: %w", err)
	}
	user.PasswordHash = string(hashedPassword)

	if err = s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user: %w", err)
	}
	return nil
}
// GetByID returns the user with the given id (admin operation).
//
// gorm.ErrRecordNotFound from the repository is translated to
// ErrUserNotFound; any other error is wrapped with "get user".
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
	found, err := s.userRepo.GetByID(ctx, id)
	if err == nil {
		return found, nil
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, ErrUserNotFound
	}
	return nil, fmt.Errorf("get user: %w", err)
}
// List returns one page of users plus pagination metadata (admin operation).
// Paging behaviour is delegated entirely to the repository; errors are
// wrapped with "list users" and all other results are nil on failure.
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
	users, page, err := s.userRepo.List(ctx, params)
	if err == nil {
		return users, page, nil
	}
	return nil, nil, fmt.Errorf("list users: %w", err)
}
// UpdateBalance forwards a balance change for userID to the repository
// (admin operation) and wraps any failure with "update balance".
//
// NOTE(review): whether amount is a delta or an absolute balance is decided
// by repository.UserRepository.UpdateBalance — confirm there.
func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount float64) error {
	err := s.userRepo.UpdateBalance(ctx, userID, amount)
	if err != nil {
		return fmt.Errorf("update balance: %w", err)
	}
	return nil
}
// UpdateStatus sets the Status of the user identified by userID and saves
// the full record (admin operation).
//
// The status string is stored verbatim; no validation is done here.
// Lookup errors come from GetByID: ErrUserNotFound, or a wrapped "get user"
// error. Save failures are wrapped with "update user".
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
	target, err := s.GetByID(ctx, userID)
	if err != nil {
		return err
	}

	target.Status = status
	if saveErr := s.userRepo.Update(ctx, target); saveErr != nil {
		return fmt.Errorf("update user: %w", saveErr)
	}
	return nil
}
// Delete removes the user identified by userID (admin operation), wrapping
// any repository failure with "delete user".
//
// NOTE(review): this method does not itself report a missing user; confirm
// whether repository.UserRepository.Delete errors when no row matches.
func (s *UserService) Delete(ctx context.Context, userID int64) error {
	err := s.userRepo.Delete(ctx, userID)
	if err == nil {
		return nil
	}
	return fmt.Errorf("delete user: %w", err)
}