Skip to content

Commit

Permalink
Adding a Checkerboard mesh utility to Pytorch3d
Browse files Browse the repository at this point in the history
Summary: Adding a checkerboard mesh utility to Pytorch3d.

Reviewed By: bottler

Differential Revision: D39718916

fbshipit-source-id: d43cd30e566b5db068bae6eed0388057634428c8
  • Loading branch information
micramamonjisoa authored and facebook-github-bot committed Sep 22, 2022
1 parent f34da3d commit ce3fce4
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
1 change: 1 addition & 0 deletions pytorch3d/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
pulsar_from_cameras_projection,
pulsar_from_opencv_projection,
)
from .checkerboard import checkerboard
from .ico_sphere import ico_sphere
from .torus import torus

Expand Down
89 changes: 89 additions & 0 deletions pytorch3d/utils/checkerboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Tuple

import torch
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.renderer.mesh.textures import TexturesAtlas
from pytorch3d.structures.meshes import Meshes


def checkerboard(
radius: int = 4,
color1: Tuple[float, ...] = (0.0, 0.0, 0.0),
color2: Tuple[float, ...] = (1.0, 1.0, 1.0),
device: Optional[torch.types._device] = None,
) -> Meshes:
"""
Returns a mesh of squares in the xy-plane where each unit is one of the two given
colors and adjacent squares have opposite colors.
Args:
radius: how many squares in each direction from the origin
color1: background color
color2: foreground color (must have the same number of channels as color1)
Returns:
new Meshes object containing one mesh.
"""

if device is None:
device = torch.device("cpu")
if radius < 1:
raise ValueError("radius must be > 0")

num_verts_per_row = 2 * radius + 1

# construct 2D grid of 3D vertices
x = torch.arange(-radius, radius + 1, device=device)
grid_y, grid_x = meshgrid_ij(x, x)
verts = torch.stack(
[grid_x, grid_y, torch.zeros((2 * radius + 1, 2 * radius + 1))], dim=-1
)
verts = verts.view(1, -1, 3)

top_triangle_idx = torch.arange(0, num_verts_per_row * (num_verts_per_row - 1))
top_triangle_idx = torch.stack(
[
top_triangle_idx,
top_triangle_idx + 1,
top_triangle_idx + num_verts_per_row + 1,
],
dim=-1,
)

bottom_triangle_idx = top_triangle_idx[:, [0, 2, 1]] + torch.tensor(
[0, 0, num_verts_per_row - 1]
)

faces = torch.zeros(
(1, len(top_triangle_idx) + len(bottom_triangle_idx), 3),
dtype=torch.long,
device=device,
)
faces[0, ::2] = top_triangle_idx
faces[0, 1::2] = bottom_triangle_idx

# construct range of indices that excludes the boundary to avoid wrong triangles
indexing_range = torch.arange(0, 2 * num_verts_per_row * num_verts_per_row).view(
num_verts_per_row, num_verts_per_row, 2
)
indexing_range = indexing_range[:-1, :-1] # removes boundaries from list of indices
indexing_range = indexing_range.reshape(
2 * (num_verts_per_row - 1) * (num_verts_per_row - 1)
)

faces = faces[:, indexing_range]

# adding color
colors = torch.tensor(color1).repeat(2 * num_verts_per_row * num_verts_per_row, 1)
colors[2::4] = torch.tensor(color2)
colors[3::4] = torch.tensor(color2)
colors = colors[None, indexing_range, None, None]

texture_atlas = TexturesAtlas(colors)

return Meshes(verts=verts, faces=faces, textures=texture_atlas)
21 changes: 21 additions & 0 deletions tests/test_checkerboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.utils import checkerboard

from .common_testing import TestCaseMixin


class TestCheckerboard(TestCaseMixin, unittest.TestCase):
def test_simple(self):
board = checkerboard(5)
verts = board.verts_packed()
expect = torch.tensor([5.0, 5.0, 0])
self.assertClose(verts.min(dim=0).values, -expect)
self.assertClose(verts.max(dim=0).values, expect)

0 comments on commit ce3fce4

Please sign in to comment.