Skip to content

Commit ce3fce4

Browse files
Adding a Checkerboard mesh utility to Pytorch3d
Summary: Adding a checkerboard mesh utility to Pytorch3d. Reviewed By: bottler Differential Revision: D39718916 fbshipit-source-id: d43cd30e566b5db068bae6eed0388057634428c8
1 parent f34da3d commit ce3fce4

File tree

3 files changed

+111
-0
lines changed

3 files changed

+111
-0
lines changed

pytorch3d/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
pulsar_from_cameras_projection,
1111
pulsar_from_opencv_projection,
1212
)
13+
from .checkerboard import checkerboard
1314
from .ico_sphere import ico_sphere
1415
from .torus import torus
1516

pytorch3d/utils/checkerboard.py

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from typing import Optional, Tuple
9+
10+
import torch
11+
from pytorch3d.common.compat import meshgrid_ij
12+
from pytorch3d.renderer.mesh.textures import TexturesAtlas
13+
from pytorch3d.structures.meshes import Meshes
14+
15+
16+
def checkerboard(
17+
radius: int = 4,
18+
color1: Tuple[float, ...] = (0.0, 0.0, 0.0),
19+
color2: Tuple[float, ...] = (1.0, 1.0, 1.0),
20+
device: Optional[torch.types._device] = None,
21+
) -> Meshes:
22+
"""
23+
Returns a mesh of squares in the xy-plane where each unit is one of the two given
24+
colors and adjacent squares have opposite colors.
25+
Args:
26+
radius: how many squares in each direction from the origin
27+
color1: background color
28+
color2: foreground color (must have the same number of channels as color1)
29+
Returns:
30+
new Meshes object containing one mesh.
31+
"""
32+
33+
if device is None:
34+
device = torch.device("cpu")
35+
if radius < 1:
36+
raise ValueError("radius must be > 0")
37+
38+
num_verts_per_row = 2 * radius + 1
39+
40+
# construct 2D grid of 3D vertices
41+
x = torch.arange(-radius, radius + 1, device=device)
42+
grid_y, grid_x = meshgrid_ij(x, x)
43+
verts = torch.stack(
44+
[grid_x, grid_y, torch.zeros((2 * radius + 1, 2 * radius + 1))], dim=-1
45+
)
46+
verts = verts.view(1, -1, 3)
47+
48+
top_triangle_idx = torch.arange(0, num_verts_per_row * (num_verts_per_row - 1))
49+
top_triangle_idx = torch.stack(
50+
[
51+
top_triangle_idx,
52+
top_triangle_idx + 1,
53+
top_triangle_idx + num_verts_per_row + 1,
54+
],
55+
dim=-1,
56+
)
57+
58+
bottom_triangle_idx = top_triangle_idx[:, [0, 2, 1]] + torch.tensor(
59+
[0, 0, num_verts_per_row - 1]
60+
)
61+
62+
faces = torch.zeros(
63+
(1, len(top_triangle_idx) + len(bottom_triangle_idx), 3),
64+
dtype=torch.long,
65+
device=device,
66+
)
67+
faces[0, ::2] = top_triangle_idx
68+
faces[0, 1::2] = bottom_triangle_idx
69+
70+
# construct range of indices that excludes the boundary to avoid wrong triangles
71+
indexing_range = torch.arange(0, 2 * num_verts_per_row * num_verts_per_row).view(
72+
num_verts_per_row, num_verts_per_row, 2
73+
)
74+
indexing_range = indexing_range[:-1, :-1] # removes boundaries from list of indices
75+
indexing_range = indexing_range.reshape(
76+
2 * (num_verts_per_row - 1) * (num_verts_per_row - 1)
77+
)
78+
79+
faces = faces[:, indexing_range]
80+
81+
# adding color
82+
colors = torch.tensor(color1).repeat(2 * num_verts_per_row * num_verts_per_row, 1)
83+
colors[2::4] = torch.tensor(color2)
84+
colors[3::4] = torch.tensor(color2)
85+
colors = colors[None, indexing_range, None, None]
86+
87+
texture_atlas = TexturesAtlas(colors)
88+
89+
return Meshes(verts=verts, faces=faces, textures=texture_atlas)

tests/test_checkerboard.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from pytorch3d.utils import checkerboard
11+
12+
from .common_testing import TestCaseMixin
13+
14+
15+
class TestCheckerboard(TestCaseMixin, unittest.TestCase):
16+
def test_simple(self):
17+
board = checkerboard(5)
18+
verts = board.verts_packed()
19+
expect = torch.tensor([5.0, 5.0, 0])
20+
self.assertClose(verts.min(dim=0).values, -expect)
21+
self.assertClose(verts.max(dim=0).values, expect)

0 commit comments

Comments
 (0)