From 46cb5aaaae0cd40f729fd41a39c0c9a232b484c0 Mon Sep 17 00:00:00 2001 From: Jiali Duan Date: Thu, 20 Oct 2022 16:05:22 -0700 Subject: [PATCH] Omit _check_valid_rotation_matrix by default Summary: According to the profiler trace D40326775, _check_valid_rotation_matrix is slow because of aten::all_close operation and _safe_det_3x3 bottlenecks. Disable the check by default unless environment variable PYTORCH3D_CHECK_ROTATION_MATRICES is set to 1. Comparison after applying the change: ``` Profiling/Function get_world_to_view (ms) Transform_points(ms) specular(ms) before 12.751 18.577 21.384 after 4.432 (34.7%) 9.248 (49.8%) 11.507 (53.8%) ``` Profiling trace: https://pxl.cl/2h687 More details in https://docs.google.com/document/d/1kfhEQfpeQToikr5OH9ZssM39CskxWoJ2p8DO5-t6eWk/edit?usp=sharing Reviewed By: kjchalup Differential Revision: D40442503 fbshipit-source-id: 954b58de47de235c9d93af441643c22868b547d0 --- pytorch3d/transforms/transform3d.py | 6 +++++- tests/test_transforms.py | 23 +++++++++++++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index f992db4d8..681742de0 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import math +import os import warnings from typing import List, Optional, Union @@ -636,7 +637,10 @@ def __init__( msg = "R must have shape (3, 3) or (N, 3, 3); got %s" raise ValueError(msg % repr(R.shape)) R = R.to(device=device_, dtype=dtype) - _check_valid_rotation_matrix(R, tol=orthogonal_tol) + if os.environ.get("PYTORCH3D_CHECK_ROTATION_MATRICES", "0") == "1": + # Note: aten::all_close in the check is computationally slow, so we + # only run the check when PYTORCH3D_CHECK_ROTATION_MATRICES is on. + _check_valid_rotation_matrix(R, tol=orthogonal_tol) N = R.shape[0] mat = torch.eye(4, dtype=dtype, device=device_) mat = mat.view(1, 4, 4).repeat(N, 1, 1) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3d8ebd628..bfd67febe 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - import math +import os import unittest +from unittest import mock import torch from pytorch3d.transforms import random_rotations @@ -191,7 +192,25 @@ def test_translate(self): self.assertTrue(torch.allclose(points_out, points_out_expected)) self.assertTrue(torch.allclose(normals_out, normals_out_expected)) - def test_rotate(self): + @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True) + def test_rotate_check_rot_valid_on(self): + R = so3_exp_map(torch.randn((1, 3))) + t = Transform3d().rotate(R) + points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( + 1, 3, 3 + ) + normals = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] + ).view(1, 3, 3) + points_out = t.transform_points(points) + normals_out = t.transform_normals(normals) + points_out_expected = torch.bmm(points, R) + normals_out_expected = torch.bmm(normals, R) + self.assertTrue(torch.allclose(points_out, points_out_expected)) + self.assertTrue(torch.allclose(normals_out, normals_out_expected)) + + @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True) + def test_rotate_check_rot_valid_off(self): R = so3_exp_map(torch.randn((1, 3))) t = Transform3d().rotate(R) points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(