From ce3fce49d7ad1a680d8c9be660164d5f7a0bb976 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Ramamonjisoa?= Date: Thu, 22 Sep 2022 12:51:37 -0700 Subject: [PATCH] Adding a Checkerboard mesh utility to Pytorch3d Summary: Adding a checkerboard mesh utility to Pytorch3d. Reviewed By: bottler Differential Revision: D39718916 fbshipit-source-id: d43cd30e566b5db068bae6eed0388057634428c8 --- pytorch3d/utils/__init__.py | 1 + pytorch3d/utils/checkerboard.py | 89 +++++++++++++++++++++++++++++++++ tests/test_checkerboard.py | 21 ++++++++ 3 files changed, 111 insertions(+) create mode 100644 pytorch3d/utils/checkerboard.py create mode 100644 tests/test_checkerboard.py diff --git a/pytorch3d/utils/__init__.py b/pytorch3d/utils/__init__.py index a9ec1581e..f3681e823 100644 --- a/pytorch3d/utils/__init__.py +++ b/pytorch3d/utils/__init__.py @@ -10,6 +10,7 @@ pulsar_from_cameras_projection, pulsar_from_opencv_projection, ) +from .checkerboard import checkerboard from .ico_sphere import ico_sphere from .torus import torus diff --git a/pytorch3d/utils/checkerboard.py b/pytorch3d/utils/checkerboard.py new file mode 100644 index 000000000..625c08684 --- /dev/null +++ b/pytorch3d/utils/checkerboard.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional, Tuple + +import torch +from pytorch3d.common.compat import meshgrid_ij +from pytorch3d.renderer.mesh.textures import TexturesAtlas +from pytorch3d.structures.meshes import Meshes + + +def checkerboard( + radius: int = 4, + color1: Tuple[float, ...] = (0.0, 0.0, 0.0), + color2: Tuple[float, ...] = (1.0, 1.0, 1.0), + device: Optional[torch.types._device] = None, +) -> Meshes: + """ + Returns a mesh of squares in the xy-plane where each unit is one of the two given + colors and adjacent squares have opposite colors. + Args: + radius: how many squares in each direction from the origin + color1: background color + color2: foreground color (must have the same number of channels as color1) + Returns: + new Meshes object containing one mesh. + """ + + if device is None: + device = torch.device("cpu") + if radius < 1: + raise ValueError("radius must be > 0") + + num_verts_per_row = 2 * radius + 1 + + # construct 2D grid of 3D vertices + x = torch.arange(-radius, radius + 1, device=device) + grid_y, grid_x = meshgrid_ij(x, x) + verts = torch.stack( + [grid_x, grid_y, torch.zeros((2 * radius + 1, 2 * radius + 1))], dim=-1 + ) + verts = verts.view(1, -1, 3) + + top_triangle_idx = torch.arange(0, num_verts_per_row * (num_verts_per_row - 1)) + top_triangle_idx = torch.stack( + [ + top_triangle_idx, + top_triangle_idx + 1, + top_triangle_idx + num_verts_per_row + 1, + ], + dim=-1, + ) + + bottom_triangle_idx = top_triangle_idx[:, [0, 2, 1]] + torch.tensor( + [0, 0, num_verts_per_row - 1] + ) + + faces = torch.zeros( + (1, len(top_triangle_idx) + len(bottom_triangle_idx), 3), + dtype=torch.long, + device=device, + ) + faces[0, ::2] = top_triangle_idx + faces[0, 1::2] = bottom_triangle_idx + + # construct range of indices that excludes the boundary to avoid wrong triangles + indexing_range = torch.arange(0, 2 * num_verts_per_row * num_verts_per_row).view( + num_verts_per_row, num_verts_per_row, 2 + ) + indexing_range = indexing_range[:-1, :-1] # removes boundaries from list of indices + indexing_range = indexing_range.reshape( + 2 * (num_verts_per_row - 1) * (num_verts_per_row - 1) + ) + + faces = faces[:, indexing_range] + + # adding color + colors = torch.tensor(color1).repeat(2 * num_verts_per_row * num_verts_per_row, 1) + colors[2::4] = torch.tensor(color2) + colors[3::4] = torch.tensor(color2) + colors = colors[None, indexing_range, None, None] + + texture_atlas = TexturesAtlas(colors) + + return Meshes(verts=verts, faces=faces, textures=texture_atlas) diff --git a/tests/test_checkerboard.py b/tests/test_checkerboard.py new file mode 100644 index 000000000..7da0dbace --- /dev/null +++ b/tests/test_checkerboard.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from pytorch3d.utils import checkerboard + +from .common_testing import TestCaseMixin + + +class TestCheckerboard(TestCaseMixin, unittest.TestCase): + def test_simple(self): + board = checkerboard(5) + verts = board.verts_packed() + expect = torch.tensor([5.0, 5.0, 0]) + self.assertClose(verts.min(dim=0).values, -expect) + self.assertClose(verts.max(dim=0).values, expect)