Skip to content

Commit 4d043fc

Browse files
bottlerfacebook-github-bot
authored andcommitted
PyTorch 1.7 compatibility
Summary: Small changes discovered based on circleCI failures. Reviewed By: patricklabatut Differential Revision: D34426807 fbshipit-source-id: 819860f34b2f367dd24057ca7490284204180a13
1 parent f816568 commit 4d043fc

File tree

9 files changed

+34
-23
lines changed

9 files changed

+34
-23
lines changed

pytorch3d/common/compat.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Tuple
7+
from typing import Sequence, Tuple, Union
88

99
import torch
1010

@@ -57,4 +57,16 @@ def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no co
5757
"""
5858
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
5959
return torch.linalg.eigh(A)
60-
return torch.symeig(A, eigenvalues=True)
60+
return torch.symeig(A, eigenvectors=True)
61+
62+
63+
def meshgrid_ij(
64+
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
65+
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
66+
"""
67+
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
68+
"""
69+
if "indexing" in torch.meshgrid.__kwdefaults__:
70+
# PyTorch >= 1.10.0
71+
return torch.meshgrid(*A, indexing="ij")
72+
return torch.meshgrid(*A)

pytorch3d/io/mtl_io.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch
1414
import torch.nn.functional as F
1515
from iopath.common.file_io import PathManager
16+
from pytorch3d.common.compat import meshgrid_ij
1617
from pytorch3d.common.datatypes import Device
1718
from pytorch3d.io.utils import _open_file, _read_image
1819

@@ -273,7 +274,7 @@ def make_material_atlas(
273274

274275
# Meshgrid returns (row, column) i.e (Y, X)
275276
# Change order to (X, Y) to make the grid.
276-
Y, X = torch.meshgrid(rng, rng)
277+
Y, X = meshgrid_ij(rng, rng)
277278
# pyre-fixme[28]: Unexpected keyword argument `axis`.
278279
grid = torch.stack([X, Y], axis=-1) # (R, R, 2)
279280

pytorch3d/ops/cubify.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99
import torch.nn.functional as F
10+
from pytorch3d.common.compat import meshgrid_ij
1011
from pytorch3d.structures import Meshes
1112

1213

@@ -195,9 +196,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
195196
# NF x 3
196197
grid_faces = torch.stack(grid_faces, dim=1)
197198

198-
y, x, z = torch.meshgrid(
199-
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
200-
)
199+
y, x, z = meshgrid_ij(torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1))
201200
y = y.to(device=device, dtype=torch.float32)
202201
x = x.to(device=device, dtype=torch.float32)
203202
z = z.to(device=device, dtype=torch.float32)

pytorch3d/renderer/implicit/raysampling.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Optional
99

1010
import torch
11+
from pytorch3d.common.compat import meshgrid_ij
1112
from pytorch3d.renderer.cameras import CamerasBase
1213
from pytorch3d.renderer.implicit.utils import RayBundle
1314
from torch.nn import functional as F
@@ -103,7 +104,7 @@ def __init__(
103104
_xy_grid = torch.stack(
104105
tuple(
105106
reversed(
106-
torch.meshgrid(
107+
meshgrid_ij(
107108
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
108109
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
109110
)

pytorch3d/structures/volumes.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from typing import List, Optional, Tuple, Union
99

1010
import torch
11+
from pytorch3d.common.compat import meshgrid_ij
12+
from pytorch3d.common.datatypes import Device, make_device
13+
from pytorch3d.transforms import Scale, Transform3d
1114

12-
from ..common.datatypes import Device, make_device
13-
from ..transforms import Scale, Transform3d
1415
from . import utils as struct_utils
1516

1617

@@ -393,7 +394,7 @@ def _calculate_coordinate_grid(
393394
]
394395

395396
# generate per-coord meshgrids
396-
Z, Y, X = torch.meshgrid(vol_axes)
397+
Z, Y, X = meshgrid_ij(vol_axes)
397398

398399
# stack the coord grids ... this order matches the coordinate convention
399400
# of torch.nn.grid_sample

tests/test_point_mesh_distance.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
from common_testing import TestCaseMixin, get_random_cuda_device
1212
from pytorch3d import _C
1313
from pytorch3d.loss import point_mesh_edge_distance, point_mesh_face_distance
14-
from pytorch3d.structures import (
15-
Meshes,
16-
Pointclouds,
17-
packed_to_list,
18-
)
14+
from pytorch3d.structures import Meshes, Pointclouds, packed_to_list
1915

2016

2117
class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):

tests/test_pointclouds.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,7 @@ def inside_box_naive(cloud, box_min, box_max):
10331033
for i, cloud in enumerate(clouds.points_list()):
10341034
within_box_naive.append(inside_box_naive(cloud, box[i, 0], box[i, 1]))
10351035
within_box_naive = torch.cat(within_box_naive, 0)
1036-
self.assertClose(within_box, within_box_naive)
1036+
self.assertTrue(torch.equal(within_box, within_box_naive))
10371037

10381038
# box of shape 2x3
10391039
box2 = box[0, :]
@@ -1044,13 +1044,12 @@ def inside_box_naive(cloud, box_min, box_max):
10441044
for cloud in clouds.points_list():
10451045
within_box_naive2.append(inside_box_naive(cloud, box2[0], box2[1]))
10461046
within_box_naive2 = torch.cat(within_box_naive2, 0)
1047-
self.assertClose(within_box2, within_box_naive2)
1048-
1047+
self.assertTrue(torch.equal(within_box2, within_box_naive2))
10491048
# box of shape 1x2x3
10501049
box3 = box2.expand(1, 2, 3)
10511050

10521051
within_box3 = clouds.inside_box(box3)
1053-
self.assertClose(within_box2, within_box3)
1052+
self.assertTrue(torch.equal(within_box2, within_box3))
10541053

10551054
# invalid box
10561055
invalid_box = torch.cat(

tests/test_raysampling.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
from common_testing import TestCaseMixin
12+
from pytorch3d.common.compat import meshgrid_ij
1213
from pytorch3d.ops import eyes
1314
from pytorch3d.renderer import (
1415
MonteCarloRaysampler,
@@ -86,7 +87,7 @@ def _get_ndc_grid(self, h, w, device):
8687
min_y = range_y - half_pix_height
8788
max_y = -range_y + half_pix_height
8889

89-
y_grid, x_grid = torch.meshgrid(
90+
y_grid, x_grid = meshgrid_ij(
9091
torch.linspace(min_y, max_y, h, dtype=torch.float32),
9192
torch.linspace(min_x, max_x, w, dtype=torch.float32),
9293
)
@@ -540,7 +541,7 @@ def test_jiggle(self):
540541
self.assertTupleEqual(out.shape, data.shape)
541542

542543
# Check `out` is in ascending order
543-
self.assertGreater(torch.diff(out, dim=-1).min(), 0)
544+
self.assertGreater((out[..., 1:] - out[..., :-1]).min(), 0)
544545

545546
self.assertConstant(out[..., :-1] < data[..., 1:], True)
546547
self.assertConstant(data[..., :-1] < out[..., 1:], True)

tests/test_rendering_utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
import torch
1212
from common_testing import TestCaseMixin
13+
from pytorch3d.common.compat import meshgrid_ij
1314
from pytorch3d.ops import eyes
1415
from pytorch3d.renderer import (
1516
AlphaCompositor,
@@ -129,8 +130,8 @@ def test_ndc_grid_sample_rendering(self):
129130
point_radius = 0.015
130131
n_pts = n_grid_pts * n_grid_pts
131132
pts = torch.stack(
132-
torch.meshgrid(
133-
[torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2, indexing="ij"
133+
meshgrid_ij(
134+
[torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2,
134135
),
135136
dim=-1,
136137
)

0 commit comments

Comments
 (0)