Skip to content

Commit 94da884

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Align_corners switch in Volumes
Summary: Porting this commit by davnov134 . fairinternal/pytorch3d@93a3a62#diff-a8e107ebe039de52ca112ac6ddfba6ebccd53b4f53030b986e13f019fe57a378 Capability to interpret world/local coordinates with various align_corners semantics. Reviewed By: bottler Differential Revision: D51855420 fbshipit-source-id: 834cd220c25d7f0143d8a55ba880da5977099dd6
1 parent fbc6725 commit 94da884

File tree

5 files changed

+91
-9
lines changed

5 files changed

+91
-9
lines changed

pytorch3d/implicitron/tools/model_io.py

+7
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,13 @@ def save_model(model, stats, fl, optimizer=None, cfg=None):
9898
return flstats, flmodel, flopt
9999

100100

101+
def save_stats(stats, fl, cfg=None):
102+
flstats = get_stats_path(fl)
103+
logger.info("saving model stats to %s" % flstats)
104+
stats.save(flstats)
105+
return flstats
106+
107+
101108
def load_model(fl, map_location: Optional[dict]):
102109
flstats = get_stats_path(fl)
103110
flmodel = get_model_path(fl)

pytorch3d/ops/points_to_volumes.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def add_pointclouds_to_volumes(
291291
mask=mask,
292292
mode=mode,
293293
rescale_features=rescale_features,
294+
align_corners=initial_volumes.get_align_corners(),
294295
_python=_python,
295296
)
296297

@@ -310,6 +311,7 @@ def add_points_features_to_volume_densities_features(
310311
grid_sizes: Optional[torch.LongTensor] = None,
311312
rescale_features: bool = True,
312313
_python: bool = False,
314+
align_corners: bool = True,
313315
) -> Tuple[torch.Tensor, torch.Tensor]:
314316
"""
315317
Convert a batch of point clouds represented with tensors of per-point
@@ -356,6 +358,7 @@ def add_points_features_to_volume_densities_features(
356358
output densities are just summed without rescaling, so
357359
you may need to rescale them afterwards.
358360
_python: Set to True to use a pure Python implementation.
361+
align_corners: as for grid_sample.
359362
Returns:
360363
volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)`
361364
volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)`
@@ -409,7 +412,7 @@ def add_points_features_to_volume_densities_features(
409412
grid_sizes,
410413
1.0, # point_weight
411414
mask,
412-
True, # align_corners
415+
align_corners, # align_corners
413416
splat,
414417
)
415418

pytorch3d/renderer/implicit/renderer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,9 @@ def forward(
382382
rays_densities = torch.nn.functional.grid_sample(
383383
volumes_densities,
384384
rays_points_local_flat,
385-
align_corners=True,
386385
mode=self._sample_mode,
387386
padding_mode=self._padding_mode,
387+
align_corners=self._volumes.get_align_corners(),
388388
)
389389

390390
# permute the dimensions & reshape densities after sampling
@@ -400,9 +400,9 @@ def forward(
400400
rays_features = torch.nn.functional.grid_sample(
401401
volumes_features,
402402
rays_points_local_flat,
403-
align_corners=True,
404403
mode=self._sample_mode,
405404
padding_mode=self._padding_mode,
405+
align_corners=self._volumes.get_align_corners(),
406406
)
407407

408408
# permute the dimensions & reshape features after sampling

pytorch3d/structures/volumes.py

+35-6
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class Volumes:
8585
are linearly interpolated over the spatial dimensions of the volume.
8686
- Note that the convention is the same as for the 5D version of the
8787
`torch.nn.functional.grid_sample` function called with
88-
`align_corners==True`.
88+
the same value of `align_corners` argument.
8989
- Note that the local coordinate convention of `Volumes`
9090
(+X = left to right, +Y = top to bottom, +Z = away from the user)
9191
is *different* from the world coordinate convention of the
@@ -143,7 +143,7 @@ class Volumes:
143143
torch.nn.functional.grid_sample(
144144
v.densities(),
145145
v.get_coord_grid(world_coordinates=False),
146-
align_corners=True,
146+
align_corners=align_corners,
147147
) == v.densities(),
148148
149149
i.e. sampling the volume at trivial local coordinates
@@ -157,6 +157,7 @@ def __init__(
157157
features: Optional[_TensorBatch] = None,
158158
voxel_size: _VoxelSize = 1.0,
159159
volume_translation: _Translation = (0.0, 0.0, 0.0),
160+
align_corners: bool = True,
160161
) -> None:
161162
"""
162163
Args:
@@ -186,6 +187,10 @@ def __init__(
186187
b) a Tensor of shape (3,)
187188
c) a Tensor of shape (minibatch, 3)
188189
d) a Tensor of shape (1,) (square voxels)
190+
**align_corners**: If set (default), the coordinates of the corner voxels are
191+
exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
192+
correspond to the centers of the corner voxels. Cf. the namesake argument to
193+
`torch.nn.functional.grid_sample`.
189194
"""
190195

191196
# handle densities
@@ -206,6 +211,7 @@ def __init__(
206211
voxel_size=voxel_size,
207212
volume_translation=volume_translation,
208213
device=self.device,
214+
align_corners=align_corners,
209215
)
210216

211217
# handle features
@@ -336,6 +342,13 @@ def features_list(self) -> List[torch.Tensor]:
336342
return None
337343
return self._features_densities_list(features_)
338344

345+
def get_align_corners(self) -> bool:
346+
"""
347+
Return whether the corners of the voxels should be aligned with the
348+
image pixels.
349+
"""
350+
return self.locator._align_corners
351+
339352
def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
340353
"""
341354
Retrieve the list representation of features/densities.
@@ -576,7 +589,7 @@ class VolumeLocator:
576589
are linearly interpolated over the spatial dimensions of the volume.
577590
- Note that the convention is the same as for the 5D version of the
578591
`torch.nn.functional.grid_sample` function called with
579-
`align_corners==True`.
592+
the same value of `align_corners` argument.
580593
- Note that the local coordinate convention of `VolumeLocator`
581594
(+X = left to right, +Y = top to bottom, +Z = away from the user)
582595
is *different* from the world coordinate convention of the
@@ -634,7 +647,7 @@ class VolumeLocator:
634647
torch.nn.functional.grid_sample(
635648
v.densities(),
636649
v.get_coord_grid(world_coordinates=False),
637-
align_corners=True,
650+
align_corners=align_corners,
638651
) == v.densities(),
639652
640653
i.e. sampling the volume at trivial local coordinates
@@ -651,6 +664,7 @@ def __init__(
651664
device: torch.device,
652665
voxel_size: _VoxelSize = 1.0,
653666
volume_translation: _Translation = (0.0, 0.0, 0.0),
667+
align_corners: bool = True,
654668
):
655669
"""
656670
**batch_size** : Batch size of the underlying grids
@@ -674,15 +688,21 @@ def __init__(
674688
b) a Tensor of shape (3,)
675689
c) a Tensor of shape (minibatch, 3)
676690
d) a Tensor of shape (1,) (square voxels)
691+
**align_corners**: If set (default), the coordinates of the corner voxels are
692+
exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
693+
correspond to the centers of the corner voxels. Cf. the namesake argument to
694+
`torch.nn.functional.grid_sample`.
677695
"""
678696
self.device = device
679697
self._batch_size = batch_size
680698
self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes)
681699
self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values)
700+
self._align_corners = align_corners
682701

683702
# set the local_to_world transform
684703
self._set_local_to_world_transform(
685-
voxel_size=voxel_size, volume_translation=volume_translation
704+
voxel_size=voxel_size,
705+
volume_translation=volume_translation,
686706
)
687707

688708
def _convert_grid_sizes2tensor(
@@ -806,8 +826,17 @@ def _calculate_coordinate_grid(
806826
grid_sizes = self.get_grid_sizes()
807827

808828
# generate coordinate axes
829+
def corner_coord_adjustment(r):
830+
return 0.0 if self._align_corners else 1.0 / r
831+
809832
vol_axes = [
810-
torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device)
833+
torch.linspace(
834+
-1.0 + corner_coord_adjustment(r),
835+
1.0 - corner_coord_adjustment(r),
836+
r,
837+
dtype=torch.float32,
838+
device=self.device,
839+
)
811840
for r in (de, he, wi)
812841
]
813842

tests/test_volumes.py

+43
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,49 @@ def test_coord_grid_convention(
312312
).permute(0, 2, 3, 4, 1)
313313
self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
314314

315+
for align_corners in [True, False]:
316+
v_trivial = Volumes(densities=densities, align_corners=align_corners)
317+
318+
# check the case with x_world=(0,0,0)
319+
pts_world = torch.zeros(
320+
num_volumes, 1, 3, device=device, dtype=torch.float32
321+
)
322+
pts_local = v_trivial.world_to_local_coords(pts_world)
323+
pts_local_expected = torch.zeros_like(pts_local)
324+
self.assertClose(pts_local, pts_local_expected)
325+
326+
# check the case with x_world=(-2, 3, -2)
327+
pts_world_tuple = [-2, 3, -2]
328+
pts_world = torch.tensor(
329+
pts_world_tuple, device=device, dtype=torch.float32
330+
)[None, None].repeat(num_volumes, 1, 1)
331+
pts_local = v_trivial.world_to_local_coords(pts_world)
332+
pts_local_expected = torch.tensor(
333+
[-1, 1, -1], device=device, dtype=torch.float32
334+
)[None, None].repeat(num_volumes, 1, 1)
335+
self.assertClose(pts_local, pts_local_expected)
336+
337+
# # check that the central voxel has coords x_world=(0, 0, 0) and x_local(0, 0, 0)
338+
grid_world = v_trivial.get_coord_grid(world_coordinates=True)
339+
grid_local = v_trivial.get_coord_grid(world_coordinates=False)
340+
for grid in (grid_world, grid_local):
341+
x0 = grid[0, :, :, 2, 0]
342+
y0 = grid[0, :, 3, :, 1]
343+
z0 = grid[0, 2, :, :, 2]
344+
for coord_line in (x0, y0, z0):
345+
self.assertClose(
346+
coord_line, torch.zeros_like(coord_line), atol=1e-7
347+
)
348+
349+
# resample grid_world using grid_sampler with local coords
350+
# -> make sure the resampled version is the same as original
351+
grid_world_resampled = torch.nn.functional.grid_sample(
352+
grid_world.permute(0, 4, 1, 2, 3),
353+
grid_local,
354+
align_corners=align_corners,
355+
).permute(0, 2, 3, 4, 1)
356+
self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
357+
315358
def test_coord_grid_convention_heterogeneous(
316359
self, num_channels=4, dtype=torch.float32
317360
):

0 commit comments

Comments
 (0)