package service

import (
	"context"
	"errors"
	"fmt"

	"sub2api/internal/config"
	"sub2api/internal/model"
	"sub2api/internal/repository"

	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)

// Sentinel errors returned by UserService.
var (
	// ErrUserNotFound is returned when a user lookup fails with
	// gorm.ErrRecordNotFound.
	ErrUserNotFound = errors.New("user not found")
	// ErrPasswordIncorrect is returned by ChangePassword when the supplied
	// current password does not verify against the stored hash.
	ErrPasswordIncorrect = errors.New("current password is incorrect")
	// ErrInsufficientPerms is not referenced in this file.
	// NOTE(review): presumably consumed by callers elsewhere — confirm.
	ErrInsufficientPerms = errors.New("insufficient permissions")
)

// UpdateProfileRequest is the payload for UpdateProfile. A nil field means
// "leave this attribute unchanged".
type UpdateProfileRequest struct {
	Email       *string `json:"email"`
	Concurrency *int    `json:"concurrency"`
}

// ChangePasswordRequest is the payload for ChangePassword.
type ChangePasswordRequest struct {
	CurrentPassword string `json:"current_password"`
	NewPassword     string `json:"new_password"`
}

// UserService implements user-related operations on top of a
// repository.UserRepository.
type UserService struct {
	userRepo *repository.UserRepository
	cfg      *config.Config // stored by NewUserService; not read in this file
}

// NewUserService returns a UserService that uses userRepo for persistence
// and keeps a reference to cfg.
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
	return &UserService{
		userRepo: userRepo,
		cfg:      cfg,
	}
}

// lookupUser loads the user with the given ID. It translates
// gorm.ErrRecordNotFound into ErrUserNotFound and wraps any other repository
// error with "get user" context, so every method reports lookup failures the
// same way.
func (s *UserService) lookupUser(ctx context.Context, userID int64) (*model.User, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	switch {
	case err == nil:
		return user, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrUserNotFound
	default:
		return nil, fmt.Errorf("get user: %w", err)
	}
}

// GetProfile returns the user identified by userID.
//
// It returns ErrUserNotFound when no such user exists, or a wrapped
// repository error otherwise.
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
	return s.lookupUser(ctx, userID)
}

// UpdateProfile applies the non-nil fields of req to the user identified by
// userID, persists the result and returns the updated user.
//
// It returns ErrUserNotFound when the user does not exist and ErrEmailExists
// when req.Email differs from the current email and is already in use.
// ErrEmailExists is not declared in this file; it is expected to be defined
// elsewhere in this package.
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
	user, err := s.lookupUser(ctx, userID)
	if err != nil {
		return nil, err
	}

	// Only consult the repository when the email actually changes: an
	// unchanged email can never be rejected, so querying for it is wasted work.
	if req.Email != nil && *req.Email != user.Email {
		exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
		if err != nil {
			return nil, fmt.Errorf("check email exists: %w", err)
		}
		if exists {
			return nil, ErrEmailExists
		}
		user.Email = *req.Email
	}

	if req.Concurrency != nil {
		user.Concurrency = *req.Concurrency
	}

	if err := s.userRepo.Update(ctx, user); err != nil {
		return nil, fmt.Errorf("update user: %w", err)
	}
	return user, nil
}

// ChangePassword replaces the password of the user identified by userID
// after verifying req.CurrentPassword against the stored bcrypt hash.
//
// It returns ErrUserNotFound when the user does not exist and
// ErrPasswordIncorrect when verification fails.
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
	user, err := s.lookupUser(ctx, userID)
	if err != nil {
		return err
	}

	// Any comparison failure is reported as ErrPasswordIncorrect, not only a
	// plain mismatch.
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
		return ErrPasswordIncorrect
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("hash password: %w", err)
	}
	user.PasswordHash = string(hashedPassword)

	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user: %w", err)
	}
	return nil
}

// GetByID returns the user with the given ID (administrative operation).
//
// It returns ErrUserNotFound when no such user exists, or a wrapped
// repository error otherwise.
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
	return s.lookupUser(ctx, id)
}

// List returns the users and pagination result produced by the repository
// for params (administrative operation). Repository errors are wrapped with
// "list users" context.
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
	users, pagination, err := s.userRepo.List(ctx, params)
	if err != nil {
		return nil, nil, fmt.Errorf("list users: %w", err)
	}
	return users, pagination, nil
}

// UpdateBalance forwards amount to the repository's UpdateBalance for the
// given user (administrative operation).
// NOTE(review): whether amount is a delta or an absolute value is decided by
// the repository — confirm there.
func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount float64) error {
	if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
		return fmt.Errorf("update balance: %w", err)
	}
	return nil
}

// UpdateStatus sets the status of the user identified by userID and persists
// the change (administrative operation). The status value is stored as given;
// no validation is performed here.
//
// It returns ErrUserNotFound when the user does not exist.
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
	user, err := s.lookupUser(ctx, userID)
	if err != nil {
		return err
	}

	user.Status = status

	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user: %w", err)
	}
	return nil
}

// Delete removes the user identified by userID through the repository
// (administrative operation). Repository errors are wrapped with
// "delete user" context.
func (s *UserService) Delete(ctx context.Context, userID int64) error {
	if err := s.userRepo.Delete(ctx, userID); err != nil {
		return fmt.Errorf("delete user: %w", err)
	}
	return nil
}