package service

import (
	"context"
	"fmt"

	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)

// Sentinel errors returned by user-related service operations.
var (
	ErrUserNotFound      = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
	ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
	ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
)

// UserListFilters contains all filter options for listing users.
type UserListFilters struct {
	Status     string           // User status filter
	Role       string           // User role filter
	Search     string           // Search in email, username
	Attributes map[int64]string // Custom attribute filters: attributeID -> value
}

// UserRepository is the persistence contract that UserService relies on.
// Implementations live outside this file.
type UserRepository interface {
	Create(ctx context.Context, user *User) error
	GetByID(ctx context.Context, id int64) (*User, error)
	GetByEmail(ctx context.Context, email string) (*User, error)
	GetFirstAdmin(ctx context.Context) (*User, error)
	Update(ctx context.Context, user *User) error
	Delete(ctx context.Context, id int64) error
	List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
	ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)

	// NOTE(review): whether amount is a delta or an absolute value is decided
	// by the implementation and is not visible from this file.
	UpdateBalance(ctx context.Context, id int64, amount float64) error
	DeductBalance(ctx context.Context, id int64, amount float64) error
	UpdateConcurrency(ctx context.Context, id int64, amount int) error
	ExistsByEmail(ctx context.Context, email string) (bool, error)
	RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)

	// TOTP two-factor authentication
	UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
	EnableTotp(ctx context.Context, userID int64) error
	DisableTotp(ctx context.Context, userID int64) error
}

// UpdateProfileRequest is the payload for updating a user's profile.
// A nil field means "leave this attribute unchanged".
type UpdateProfileRequest struct {
	Email       *string `json:"email"`
	Username    *string `json:"username"`
	Concurrency *int    `json:"concurrency"`
}

// ChangePasswordRequest is the payload for changing a user's password.
type ChangePasswordRequest struct {
	CurrentPassword string `json:"current_password"`
	NewPassword     string `json:"new_password"`
}

// UserService implements user-related operations on top of a UserRepository.
type UserService struct {
	userRepo             UserRepository
	authCacheInvalidator APIKeyAuthCacheInvalidator
}

// NewUserService creates a UserService. authCacheInvalidator may be nil, in
// which case every cache invalidation step is skipped.
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
	return &UserService{
		userRepo:             userRepo,
		authCacheInvalidator: authCacheInvalidator,
	}
}

// GetFirstAdmin returns the first admin user (used for Admin API Key
// authentication).
func (s *UserService) GetFirstAdmin(ctx context.Context) (*User, error) {
	admin, err := s.userRepo.GetFirstAdmin(ctx)
	if err != nil {
		return nil, fmt.Errorf("get first admin: %w", err)
	}
	return admin, nil
}

// GetProfile returns the profile of the user identified by userID.
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}
	return user, nil
}

// UpdateProfile applies the non-nil fields of req to the user identified by
// userID, persists the result and returns the updated user.
//
// It returns ErrEmailExists when the requested email differs from the user's
// current one and the repository reports it as already in use. Repository
// failures are returned wrapped with the step that failed.
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}
	oldConcurrency := user.Concurrency

	// Only query the repository when the email actually changes: re-submitting
	// the current email can never conflict, so the lookup would be wasted.
	if req.Email != nil && *req.Email != user.Email {
		exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
		if err != nil {
			return nil, fmt.Errorf("check email exists: %w", err)
		}
		if exists {
			return nil, ErrEmailExists
		}
		user.Email = *req.Email
	}

	if req.Username != nil {
		user.Username = *req.Username
	}

	if req.Concurrency != nil {
		user.Concurrency = *req.Concurrency
	}

	if err := s.userRepo.Update(ctx, user); err != nil {
		return nil, fmt.Errorf("update user: %w", err)
	}

	// The auth cache is invalidated only when the concurrency value changed.
	if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}

	return user, nil
}

// ChangePassword verifies the current password, sets the new one and persists
// the user. It returns ErrPasswordIncorrect when req.CurrentPassword does not
// match.
//
// Security: Increments TokenVersion to invalidate all existing JWT tokens.
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return fmt.Errorf("get user: %w", err)
	}

	// Verify the current password before accepting the new one.
	if !user.CheckPassword(req.CurrentPassword) {
		return ErrPasswordIncorrect
	}

	if err := user.SetPassword(req.NewPassword); err != nil {
		return fmt.Errorf("set password: %w", err)
	}

	// Increment TokenVersion to invalidate all existing tokens.
	// This ensures that any tokens issued before the password change become invalid.
	user.TokenVersion++

	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user: %w", err)
	}

	return nil
}

// GetByID returns the user with the given ID (admin functionality).
func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
	user, err := s.userRepo.GetByID(ctx, id)
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}
	return user, nil
}

// List returns one page of users together with the pagination result (admin
// functionality).
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
	// The result is named page so it does not shadow the pagination package.
	users, page, err := s.userRepo.List(ctx, params)
	if err != nil {
		return nil, nil, fmt.Errorf("list users: %w", err)
	}
	return users, page, nil
}

// UpdateBalance forwards amount to the repository's UpdateBalance for the
// given user and then invalidates the user's auth cache (admin functionality).
func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount float64) error {
	if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
		return fmt.Errorf("update balance: %w", err)
	}
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	return nil
}

// UpdateConcurrency forwards concurrency to the repository's UpdateConcurrency
// for the given user and then invalidates the user's auth cache (admin
// functionality).
func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concurrency int) error {
	if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
		return fmt.Errorf("update concurrency: %w", err)
	}
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	return nil
}

// UpdateStatus sets the user's status, persists the user and then invalidates
// the user's auth cache (admin functionality). The status value is stored
// as given; no validation happens here.
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return fmt.Errorf("get user: %w", err)
	}

	user.Status = status

	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user: %w", err)
	}
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}

	return nil
}

// Delete removes the user with the given ID (admin functionality).
//
// Unlike the update methods, the auth cache is invalidated BEFORE the
// repository delete runs, so invalidation happens even if the delete fails.
// NOTE(review): presumably the invalidator needs the user's records to still
// exist in order to resolve the cache entries - confirm before reordering.
func (s *UserService) Delete(ctx context.Context, userID int64) error {
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	if err := s.userRepo.Delete(ctx, userID); err != nil {
		return fmt.Errorf("delete user: %w", err)
	}
	return nil
}