From 305cf32f6bfab02f735a938afed29d41d1759e4c Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 21 Sep 2022 04:29:44 -0700 Subject: [PATCH] Avoid raysampler dict Summary: A significant speedup (e.g. >2% of a forward pass). Move NDCMultinomialRaysampler parts of AbstractMaskRaySampler to members instead of living in a dict. The dict was hiding them from the nn.Module system so their _xy_grid members were remaining on the CPU. Therefore they were being copied to the GPU in every forward pass. (We couldn't easily use a ModuleDict here because the enum keys are not strs.) Reviewed By: shapovalov Differential Revision: D39668589 fbshipit-source-id: 719b88e4a08fd7263a284e0ab38189e666bd7e3a --- .../models/renderer/ray_sampler.py | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py index b876d906e..76f9f5bcb 100644 --- a/pytorch3d/implicitron/models/renderer/ray_sampler.py +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -100,34 +100,32 @@ def __post_init__(self): ), } - self._raysamplers = { - EvaluationMode.TRAINING: NDCMultinomialRaysampler( - image_width=self.image_width, - image_height=self.image_height, - n_pts_per_ray=self.n_pts_per_ray_training, - min_depth=0.0, - max_depth=0.0, - n_rays_per_image=self.n_rays_per_image_sampled_from_mask - if self._sampling_mode[EvaluationMode.TRAINING] - == RenderSamplingMode.MASK_SAMPLE - else None, - unit_directions=True, - stratified_sampling=self.stratified_point_sampling_training, - ), - EvaluationMode.EVALUATION: NDCMultinomialRaysampler( - image_width=self.image_width, - image_height=self.image_height, - n_pts_per_ray=self.n_pts_per_ray_evaluation, - min_depth=0.0, - max_depth=0.0, - n_rays_per_image=self.n_rays_per_image_sampled_from_mask - if self._sampling_mode[EvaluationMode.EVALUATION] - == RenderSamplingMode.MASK_SAMPLE - else None, - unit_directions=True, - stratified_sampling=self.stratified_point_sampling_evaluation, - ), - } + self._training_raysampler = NDCMultinomialRaysampler( + image_width=self.image_width, + image_height=self.image_height, + n_pts_per_ray=self.n_pts_per_ray_training, + min_depth=0.0, + max_depth=0.0, + n_rays_per_image=self.n_rays_per_image_sampled_from_mask + if self._sampling_mode[EvaluationMode.TRAINING] + == RenderSamplingMode.MASK_SAMPLE + else None, + unit_directions=True, + stratified_sampling=self.stratified_point_sampling_training, + ) + self._evaluation_raysampler = NDCMultinomialRaysampler( + image_width=self.image_width, + image_height=self.image_height, + n_pts_per_ray=self.n_pts_per_ray_evaluation, + min_depth=0.0, + max_depth=0.0, + n_rays_per_image=self.n_rays_per_image_sampled_from_mask + if self._sampling_mode[EvaluationMode.EVALUATION] + == RenderSamplingMode.MASK_SAMPLE + else None, + unit_directions=True, + stratified_sampling=self.stratified_point_sampling_evaluation, + ) def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]: raise NotImplementedError() @@ -169,11 +167,13 @@ def forward( min_depth, max_depth = self._get_min_max_depth_bounds(cameras) + raysampler = { + EvaluationMode.TRAINING: self._training_raysampler, + EvaluationMode.EVALUATION: self._evaluation_raysampler, + }[evaluation_mode] + # pyre-fixme[29]: - # `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self, - # torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor], - # torch.Tensor, torch.nn.Module]` is not a function. - ray_bundle = self._raysamplers[evaluation_mode]( + ray_bundle = raysampler( cameras=cameras, mask=sample_mask, min_depth=min_depth,