Skip to content

Commit d19e624

Browse files
davidsonicfacebook-github-bot
authored andcommitted
Update Rasterizer and add end2end fisheye integration test
Summary: 1) Update rasterizer/point rasterizer to accommodate fisheyecamera. Specifically, transform_points is in placement of explicit transform compositions. 2) In rasterizer unittests, update corresponding tests for rasterizer and point_rasterizer. Address comments to test fisheye against perspective camera when distortions are turned off. 3) Address comments to add end2end test for fisheyecameras. In test_render_meshes, fisheyecameras are added to camera enuerations whenever possible. 4) Test renderings with fisheyecameras of different params on cow mesh. 5) Use compositions for linear cameras whenever possible. Reviewed By: kjchalup Differential Revision: D38932736 fbshipit-source-id: 5b7074fc001f2390f4cf43c7267a8b37fd987547
1 parent b0515e1 commit d19e624

File tree

63 files changed

+566
-76
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+566
-76
lines changed

pytorch3d/renderer/cameras.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import math
88
import warnings
9-
from typing import List, Optional, Sequence, Tuple, Union
9+
from typing import Callable, List, Optional, Sequence, Tuple, Union
1010

1111
import numpy as np
1212
import torch
@@ -91,7 +91,7 @@ class CamerasBase(TensorProperties):
9191
# When joining objects into a batch, they will have to agree.
9292
_SHARED_FIELDS: Tuple[str, ...] = ()
9393

94-
def get_projection_transform(self):
94+
def get_projection_transform(self, **kwargs):
9595
"""
9696
Calculate the projective transformation matrix.
9797
@@ -1841,3 +1841,23 @@ def get_screen_to_ndc_transform(
18411841
image_size=image_size,
18421842
).inverse()
18431843
return transform
1844+
1845+
1846+
def try_get_projection_transform(cameras, kwargs) -> Optional[Callable]:
1847+
"""
1848+
Try block to get projection transform.
1849+
1850+
Args:
1851+
cameras instance, can be linear cameras or nonliear cameras
1852+
1853+
Returns:
1854+
If the camera implemented projection_transform, return the
1855+
projection transform; Otherwise, return None
1856+
"""
1857+
1858+
transform = None
1859+
try:
1860+
transform = cameras.get_projection_transform(**kwargs)
1861+
except NotImplementedError:
1862+
pass
1863+
return transform

pytorch3d/renderer/mesh/rasterizer.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
import torch.nn as nn
12+
from pytorch3d.renderer.cameras import try_get_projection_transform
1213

1314
from .rasterize_meshes import rasterize_meshes
1415

@@ -197,12 +198,19 @@ def transform(self, meshes_world, **kwargs) -> torch.Tensor:
197198
verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
198199
verts_world, eps=eps
199200
)
200-
# view to NDC transform
201+
# Call transform_points instead of explicitly composing transforms to handle
202+
# the case, where camera class does not have a projection matrix form.
203+
verts_proj = cameras.transform_points(verts_world, eps=eps)
201204
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
202-
projection_transform = cameras.get_projection_transform(**kwargs).compose(
203-
to_ndc_transform
204-
)
205-
verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
205+
projection_transform = try_get_projection_transform(cameras, kwargs)
206+
if projection_transform is not None:
207+
projection_transform = projection_transform.compose(to_ndc_transform)
208+
verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
209+
else:
210+
# Call transform_points instead of explicitly composing transforms to handle
211+
# the case, where camera class does not have a projection matrix form.
212+
verts_proj = cameras.transform_points(verts_world, eps=eps)
213+
verts_ndc = to_ndc_transform.transform_points(verts_proj, eps=eps)
206214

207215
verts_ndc[..., 2] = verts_view[..., 2]
208216
meshes_ndc = meshes_world.update_padded(new_verts_padded=verts_ndc)

pytorch3d/renderer/points/rasterizer.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
import torch.nn as nn
13+
from pytorch3d.renderer.cameras import try_get_projection_transform
1314
from pytorch3d.structures import Pointclouds
1415

1516
from .rasterize_points import rasterize_points
@@ -103,12 +104,16 @@ def transform(self, point_clouds, **kwargs) -> Pointclouds:
103104
pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
104105
pts_world, eps=eps
105106
)
106-
# view to NDC transform
107107
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
108-
projection_transform = cameras.get_projection_transform(**kwargs).compose(
109-
to_ndc_transform
110-
)
111-
pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
108+
projection_transform = try_get_projection_transform(cameras, kwargs)
109+
if projection_transform is not None:
110+
projection_transform = projection_transform.compose(to_ndc_transform)
111+
pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
112+
else:
113+
# Call transform_points instead of explicitly composing transforms to handle
114+
# the case, where camera class does not have a projection matrix form.
115+
pts_proj = cameras.transform_points(pts_world, eps=eps)
116+
pts_ndc = to_ndc_transform.transform_points(pts_proj, eps=eps)
112117

113118
pts_ndc[..., 2] = pts_view[..., 2]
114119
point_clouds = point_clouds.update_padded(pts_ndc)
5.48 KB
1.64 KB

tests/test_rasterizer.py

+159
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
PointsRasterizer,
2222
RasterizationSettings,
2323
)
24+
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
2425
from pytorch3d.renderer.opengl.rasterizer_opengl import (
2526
_check_cameras,
2627
_check_raster_settings,
@@ -51,6 +52,9 @@ class TestMeshRasterizer(unittest.TestCase):
5152
def test_simple_sphere(self):
5253
self._simple_sphere(MeshRasterizer)
5354

55+
def test_simple_sphere_fisheye(self):
56+
self._simple_sphere_fisheye_against_perspective(MeshRasterizer)
57+
5458
def test_simple_sphere_opengl(self):
5559
self._simple_sphere(MeshRasterizerOpenGL)
5660

@@ -155,6 +159,91 @@ def _simple_sphere(self, rasterizer_type):
155159

156160
self.assertTrue(torch.allclose(image, image_ref))
157161

162+
def _simple_sphere_fisheye_against_perspective(self, rasterizer_type):
163+
device = torch.device("cuda:0")
164+
165+
# Init mesh
166+
sphere_mesh = ico_sphere(5, device)
167+
168+
# Init rasterizer settings
169+
R, T = look_at_view_transform(2.7, 0, 0)
170+
171+
# Init Fisheye camera params
172+
focal = torch.tensor([[1.7321]], dtype=torch.float32)
173+
principal_point = torch.tensor([[0.0101, -0.0101]])
174+
perspective_cameras = PerspectiveCameras(
175+
R=R,
176+
T=T,
177+
focal_length=focal,
178+
principal_point=principal_point,
179+
device="cuda:0",
180+
)
181+
fisheye_cameras = FishEyeCameras(
182+
device=device,
183+
R=R,
184+
T=T,
185+
focal_length=focal,
186+
principal_point=principal_point,
187+
world_coordinates=True,
188+
use_radial=False,
189+
use_tangential=False,
190+
use_thin_prism=False,
191+
)
192+
raster_settings = RasterizationSettings(
193+
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
194+
)
195+
196+
# Init rasterizer
197+
perspective_rasterizer = rasterizer_type(
198+
cameras=perspective_cameras, raster_settings=raster_settings
199+
)
200+
fisheye_rasterizer = rasterizer_type(
201+
cameras=fisheye_cameras, raster_settings=raster_settings
202+
)
203+
204+
####################################################################################
205+
# Test rasterizing a single mesh comparing fisheye camera against perspective camera
206+
####################################################################################
207+
208+
perspective_fragments = perspective_rasterizer(sphere_mesh)
209+
perspective_image = perspective_fragments.pix_to_face[0, ..., 0].squeeze().cpu()
210+
# Convert pix_to_face to a binary mask
211+
perspective_image[perspective_image >= 0] = 1.0
212+
perspective_image[perspective_image < 0] = 0.0
213+
214+
if DEBUG:
215+
Image.fromarray((perspective_image.numpy() * 255).astype(np.uint8)).save(
216+
DATA_DIR
217+
/ f"DEBUG_test_perspective_rasterized_sphere_{rasterizer_type.__name__}.png"
218+
)
219+
220+
fisheye_fragments = fisheye_rasterizer(sphere_mesh)
221+
fisheye_image = fisheye_fragments.pix_to_face[0, ..., 0].squeeze().cpu()
222+
# Convert pix_to_face to a binary mask
223+
fisheye_image[fisheye_image >= 0] = 1.0
224+
fisheye_image[fisheye_image < 0] = 0.0
225+
226+
if DEBUG:
227+
Image.fromarray((fisheye_image.numpy() * 255).astype(np.uint8)).save(
228+
DATA_DIR
229+
/ f"DEBUG_test_fisheye_rasterized_sphere_{rasterizer_type.__name__}.png"
230+
)
231+
232+
self.assertTrue(torch.allclose(fisheye_image, perspective_image))
233+
234+
##################################
235+
# 2. Test with a batch of meshes
236+
##################################
237+
238+
batch_size = 10
239+
sphere_meshes = sphere_mesh.extend(batch_size)
240+
fragments = fisheye_rasterizer(sphere_meshes)
241+
for i in range(batch_size):
242+
image = fragments.pix_to_face[i, ..., 0].squeeze().cpu()
243+
image[image >= 0] = 1.0
244+
image[image < 0] = 0.0
245+
self.assertTrue(torch.allclose(image, perspective_image))
246+
158247
def test_simple_to(self):
159248
# Check that to() works without a cameras object.
160249
device = torch.device("cuda:0")
@@ -412,6 +501,76 @@ def test_simple_sphere(self):
412501
image[image < 0] = 0.0
413502
self.assertTrue(torch.allclose(image, image_ref[..., 0]))
414503

504+
def test_simple_sphere_fisheye_against_perspective(self):
505+
device = torch.device("cuda:0")
506+
507+
# Rescale image_ref to the 0 - 1 range and convert to a binary mask.
508+
sphere_mesh = ico_sphere(1, device)
509+
verts_padded = sphere_mesh.verts_padded()
510+
verts_padded[..., 1] += 0.2
511+
verts_padded[..., 0] += 0.2
512+
pointclouds = Pointclouds(points=verts_padded)
513+
R, T = look_at_view_transform(2.7, 0.0, 0.0)
514+
perspective_cameras = PerspectiveCameras(
515+
R=R,
516+
T=T,
517+
device=device,
518+
)
519+
fisheye_cameras = FishEyeCameras(
520+
device=device,
521+
R=R,
522+
T=T,
523+
world_coordinates=True,
524+
use_radial=False,
525+
use_tangential=False,
526+
use_thin_prism=False,
527+
)
528+
raster_settings = PointsRasterizationSettings(
529+
image_size=256, radius=5e-2, points_per_pixel=1
530+
)
531+
532+
#################################
533+
# 1. Test init without cameras.
534+
##################################
535+
536+
# Initialize without passing in the cameras
537+
rasterizer = PointsRasterizer()
538+
539+
# Check that omitting the cameras in both initialization
540+
# and the forward pass throws an error:
541+
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
542+
rasterizer(pointclouds)
543+
544+
########################################################################################
545+
# 2. Test rasterizing a single pointcloud with fisheye camera agasint perspective camera
546+
########################################################################################
547+
548+
perspective_fragments = rasterizer(
549+
pointclouds, cameras=perspective_cameras, raster_settings=raster_settings
550+
)
551+
fisheye_fragments = rasterizer(
552+
pointclouds, cameras=fisheye_cameras, raster_settings=raster_settings
553+
)
554+
555+
# Convert idx to a binary mask
556+
perspective_image = perspective_fragments.idx[0, ..., 0].squeeze().cpu()
557+
perspective_image[perspective_image >= 0] = 1.0
558+
perspective_image[perspective_image < 0] = 0.0
559+
560+
fisheye_image = fisheye_fragments.idx[0, ..., 0].squeeze().cpu()
561+
fisheye_image[fisheye_image >= 0] = 1.0
562+
fisheye_image[fisheye_image < 0] = 0.0
563+
564+
if DEBUG:
565+
Image.fromarray((perspective_image.numpy() * 255).astype(np.uint8)).save(
566+
DATA_DIR / "DEBUG_test_rasterized_perspective_sphere_points.png"
567+
)
568+
Image.fromarray((fisheye_image.numpy() * 255).astype(np.uint8)).save(
569+
DATA_DIR / "DEBUG_test_rasterized_fisheye_sphere_points.png"
570+
)
571+
572+
self.assertTrue(torch.allclose(fisheye_image, perspective_image))
573+
415574
def test_simple_to(self):
416575
# Check that to() works without a cameras object.
417576
device = torch.device("cuda:0")

0 commit comments

Comments
 (0)