Skip to content

Commit cac6cb1

Browse files
ignacio-rocco authored and facebook-github-bot committed
Update NDC raysampler for non-square convention (#29)
Summary:
- Old NDC convention had xy coords in [-1, 1]x[-1, 1].
- New NDC convention has xy coords in [-1, 1]x[-u, u] or [-u, u]x[-1, 1], where u > 1 is the aspect ratio of the image.

This PR fixes the NDC raysampler to use the new convention.

Partial fix for #868

Pull Request resolved: fairinternal/pytorch3d#29

Reviewed By: davnov134

Differential Revision: D31926148

Pulled By: bottler

fbshipit-source-id: c6c42c60d1473b04e60ceb49c8c10951ddf03c74
1 parent bfeb82e commit cac6cb1

File tree

4 files changed

+118
-35
lines changed

4 files changed

+118
-35
lines changed

pytorch3d/renderer/implicit/raysampling.py

+15-8
Original file line number | Diff line number | Diff line change
@@ -139,8 +139,8 @@ class NDCGridRaysampler(GridRaysampler):
139139
have uniformly-spaced z-coordinates between a predefined minimum and maximum depth.
140140
141141
`NDCGridRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds`
142-
renderers. I.e. the border of the leftmost / rightmost / topmost / bottommost pixel
143-
has coordinates 1.0 / -1.0 / 1.0 / -1.0 respectively.
142+
renderers. I.e. the pixel coordinates are in [-1, 1]x[-u, u] or [-u, u]x[-1, 1]
143+
where u > 1 is the aspect ratio of the image.
144144
"""
145145

146146
def __init__(
@@ -159,13 +159,20 @@ def __init__(
159159
min_depth: The minimum depth of a ray-point.
160160
max_depth: The maximum depth of a ray-point.
161161
"""
162-
half_pix_width = 1.0 / image_width
163-
half_pix_height = 1.0 / image_height
162+
if image_width >= image_height:
163+
range_x = image_width / image_height
164+
range_y = 1.0
165+
else:
166+
range_x = 1.0
167+
range_y = image_height / image_width
168+
169+
half_pix_width = range_x / image_width
170+
half_pix_height = range_y / image_height
164171
super().__init__(
165-
min_x=1.0 - half_pix_width,
166-
max_x=-1.0 + half_pix_width,
167-
min_y=1.0 - half_pix_height,
168-
max_y=-1.0 + half_pix_height,
172+
min_x=range_x - half_pix_width,
173+
max_x=-range_x + half_pix_width,
174+
min_y=range_y - half_pix_height,
175+
max_y=-range_y + half_pix_height,
169176
image_width=image_width,
170177
image_height=image_height,
171178
n_pts_per_ray=n_pts_per_ray,

tests/test_raysampling.py

+76-6
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,69 @@
2424
from test_cameras import init_random_cameras
2525

2626

27+
class TestNDCRaysamplerConvention(TestCaseMixin, unittest.TestCase):
    """
    Checks that the rays emitted by `NDCGridRaysampler` hit the non-square
    NDC grid, i.e. xy coordinates spanning [-1, 1] along the shorter image
    side and [-u, u] along the longer one, with u = long side / short side.
    """

    def setUp(self) -> None:
        # Fixed seed so that the randomly initialized camera is reproducible.
        torch.manual_seed(42)

    def test_ndc_convention(
        self,
        h=428,
        w=760,
    ):
        """
        Sample a single point per ray at unit depth, project the resulting
        points back through the same camera and compare their xy coordinates
        with the reference grid built by `_get_ndc_grid`.
        """
        device = torch.device("cuda")

        camera = init_random_cameras(PerspectiveCameras, 1, random_z=True).to(device)

        depth_map = torch.ones((1, 1, h, w)).to(device)

        raysampler = NDCGridRaysampler(
            image_width=w,
            image_height=h,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )
        # Override the sampled ray lengths with the constant depth map.
        ray_bundle = raysampler(camera)._replace(lengths=depth_map[:, 0, ..., None])
        xyz = ray_bundle_to_ray_points(ray_bundle).view(1, -1, 3)

        # project pointcloud
        projected = camera.transform_points(xyz)
        xy = projected[:, :, :2].squeeze()

        self.assertClose(
            xy,
            self._get_ndc_grid(h, w, device),
            atol=1e-4,
        )

    def _get_ndc_grid(self, h, w, device):
        """
        Build the reference grid of pixel-center xy coordinates.

        Args:
            h: Image height in pixels.
            w: Image width in pixels.
            device: Device on which the returned tensor is placed.

        Returns:
            A float32 tensor of shape `(h * w, 2)` holding `(x, y)` pairs,
            flattened row by row. Both coordinates run from positive to
            negative values.
        """
        # The shorter side spans [-1, 1]; the longer one is scaled by the
        # aspect ratio.
        range_x, range_y = (w / h, 1.0) if w >= h else (1.0, h / w)

        half_pix_width = range_x / w
        half_pix_height = range_y / h

        # Pixel centers sit half a pixel inside the image border.
        ys = torch.linspace(
            range_y - half_pix_height,
            -range_y + half_pix_height,
            h,
            dtype=torch.float32,
        )
        xs = torch.linspace(
            range_x - half_pix_width,
            -range_x + half_pix_width,
            w,
            dtype=torch.float32,
        )
        y_grid, x_grid = torch.meshgrid(ys, xs)

        flat_x = x_grid.contiguous().view(-1).to(device)
        flat_y = y_grid.contiguous().view(-1).to(device)
        return torch.stack((flat_x, flat_y), dim=1)
88+
89+
2790
class TestRaysampling(TestCaseMixin, unittest.TestCase):
2891
def setUp(self) -> None:
2992
torch.manual_seed(42)
@@ -147,12 +210,19 @@ def test_raysamplers(
147210

148211
if issubclass(raysampler_type, NDCGridRaysampler):
149212
# adjust the gt bounds for NDCGridRaysampler
150-
half_pix_width = 1.0 / image_width
151-
half_pix_height = 1.0 / image_height
152-
min_x_ = 1.0 - half_pix_width
153-
max_x_ = -1.0 + half_pix_width
154-
min_y_ = 1.0 - half_pix_height
155-
max_y_ = -1.0 + half_pix_height
213+
if image_width >= image_height:
214+
range_x = image_width / image_height
215+
range_y = 1.0
216+
else:
217+
range_x = 1.0
218+
range_y = image_height / image_width
219+
220+
half_pix_width = range_x / image_width
221+
half_pix_height = range_y / image_height
222+
min_x_ = range_x - half_pix_width
223+
max_x_ = -range_x + half_pix_width
224+
min_y_ = range_y - half_pix_height
225+
max_y_ = -range_y + half_pix_height
156226
else:
157227
min_x_ = min_x
158228
max_x_ = max_x

tests/test_render_implicit.py

+14-10
Original file line number | Diff line number | Diff line change
@@ -159,8 +159,12 @@ def test_input_types(self):
159159
with self.assertRaises(ValueError):
160160
renderer(cameras=cameras, volumetric_function=bad_volumetric_function)
161161

162-
def test_compare_with_meshes_renderer(
163-
self, batch_size=11, image_size=100, sphere_diameter=0.6
162+
def test_compare_with_meshes_renderer(self):
    # Run the comparison for both non-square image sizes, in this order.
    for image_size in ((200, 100), (100, 200)):
        self._compare_with_meshes_renderer(image_size=image_size)
165+
166+
def _compare_with_meshes_renderer(
167+
self, image_size, batch_size=11, sphere_diameter=0.6
164168
):
165169
"""
166170
Generate a spherical RGB volumetric function and its corresponding mesh
@@ -169,18 +173,16 @@ def test_compare_with_meshes_renderer(
169173
"""
170174

171175
# generate NDC camera extrinsics and intrinsics
172-
cameras = init_cameras(
173-
batch_size, image_size=[image_size, image_size], ndc=True
174-
)
176+
cameras = init_cameras(batch_size, image_size=image_size, ndc=True)
175177

176178
# get rand offset of the volume
177179
sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1
178180
sphere_centroid.requires_grad = True
179181

180182
# init the grid raysampler with the ndc grid
181183
raysampler = NDCGridRaysampler(
182-
image_width=image_size,
183-
image_height=image_size,
184+
image_width=image_size[1],
185+
image_height=image_size[0],
184186
n_pts_per_ray=256,
185187
min_depth=0.1,
186188
max_depth=2.0,
@@ -336,9 +338,11 @@ def test_compare_with_meshes_renderer(
336338
self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=5e-2)
337339
self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2)
338340

339-
def test_rotating_gif(
340-
self, n_frames=50, fps=15, image_size=(100, 100), sphere_diameter=0.5
341-
):
341+
def test_rotating_gif(self):
    # Render the animation for both non-square image sizes, in this order.
    for image_size in ((200, 100), (100, 200)):
        self._rotating_gif(image_size=image_size)
344+
345+
def _rotating_gif(self, image_size, n_frames=50, fps=15, sphere_diameter=0.5):
342346
"""
343347
Render a gif animation of a rotating sphere (runs only if `DEBUG==True`).
344348
"""

tests/test_render_volumes.py

+13-11
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,7 @@ def init_cameras(
164164
p0 = torch.ones(batch_size, 2, device=device)
165165
p0[:, 0] *= image_size[1] * 0.5
166166
p0[:, 1] *= image_size[0] * 0.5
167-
focal = image_size[0] * torch.ones(batch_size, device=device)
167+
focal = max(*image_size) * torch.ones(batch_size, device=device)
168168

169169
# convert to a Camera object
170170
cameras = PerspectiveCameras(focal, p0, R=R, T=T, device=device)
@@ -295,17 +295,15 @@ def test_input_types(self, batch_size: int = 10):
295295
_validate_ray_bundle_variables(*bad_ray_bundle)
296296

297297
def test_compare_with_pointclouds_renderer(
298-
self, batch_size=11, volume_size=(30, 30, 30), image_size=200
298+
self, batch_size=11, volume_size=(30, 30, 30), image_size=(200, 250)
299299
):
300300
"""
301301
Generate a volume and its corresponding point cloud and check whether
302302
PointsRenderer returns the same images as the corresponding VolumeRenderer.
303303
"""
304304

305305
# generate NDC camera extrinsics and intrinsics
306-
cameras = init_cameras(
307-
batch_size, image_size=[image_size, image_size], ndc=True
308-
)
306+
cameras = init_cameras(batch_size, image_size=image_size, ndc=True)
309307

310308
# init the boundary volume
311309
for shape in ("sphere", "cube"):
@@ -340,10 +338,10 @@ def test_compare_with_pointclouds_renderer(
340338

341339
# init the grid raysampler with the ndc grid
342340
coord_range = 1.0
343-
half_pix_size = coord_range / image_size
341+
half_pix_size = coord_range / max(*image_size)
344342
raysampler = NDCGridRaysampler(
345-
image_width=image_size,
346-
image_height=image_size,
343+
image_width=image_size[1],
344+
image_height=image_size[0],
347345
n_pts_per_ray=256,
348346
min_depth=0.1,
349347
max_depth=2.0,
@@ -499,8 +497,12 @@ def test_monte_carlo_rendering(
499497
images_opacities_mc.permute(0, 3, 1, 2), images_opacities_mc_, atol=1e-4
500498
)
501499

502-
def test_rotating_gif(
503-
self, n_frames=50, fps=15, volume_size=(100, 100, 100), image_size=(100, 100)
500+
def test_rotating_gif(self):
    # Render the animation for both non-square image sizes, in this order.
    for image_size in ((200, 100), (100, 200)):
        self._rotating_gif(image_size=image_size)
503+
504+
def _rotating_gif(
505+
self, image_size, n_frames=50, fps=15, volume_size=(100, 100, 100)
504506
):
505507
"""
506508
Render a gif animation of a rotating cube/sphere (runs only if `DEBUG==True`).
@@ -586,7 +588,7 @@ def test_rotating_cube_volume_render(self):
586588

587589
# batch_size = 4 sides of the cube
588590
batch_size = 4
589-
image_size = (50, 50)
591+
image_size = (50, 40)
590592

591593
for volume_size in ([25, 25, 25],):
592594
for sample_mode in ("bilinear", "nearest"):

0 commit comments

Comments
 (0)