Skip to content

Commit 34b1b4a

Browse files
bottlerfacebook-github-bot
authored andcommitted
defaulted grid_sizes in points2vols
Summary: Fix #873, that grid_sizes defaults to the wrong dtype in points2volumes code, and mask doesn't have a proper default. Reviewed By: nikhilaravi Differential Revision: D31503545 fbshipit-source-id: fa32a1a6074fc7ac7bdb362edfb5e5839866a472
1 parent 2f2466f commit 34b1b4a

File tree

5 files changed

+24
-4
lines changed

5 files changed

+24
-4
lines changed

pytorch3d/common/workaround/symeig3x3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import math
8-
from typing import Tuple, Optional
8+
from typing import Optional, Tuple
99

1010
import torch
1111
import torch.nn.functional as F

pytorch3d/ops/points_to_volumes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def add_points_features_to_volume_densities_features(
364364
# grid sizes shape (minibatch, 3)
365365
grid_sizes = (
366366
torch.LongTensor(list(volume_densities.shape[2:]))
367-
.to(volume_densities)
367+
.to(volume_densities.device)
368368
.expand(volume_densities.shape[0], 3)
369369
)
370370

@@ -386,6 +386,10 @@ def add_points_features_to_volume_densities_features(
386386
splat = False
387387
else:
388388
raise ValueError('No such interpolation mode "%s"' % mode)
389+
390+
if mask is None:
391+
mask = points_3d.new_ones(1).expand(points_3d.shape[:2])
392+
389393
volume_densities, volume_features = _points_to_volumes(
390394
points_3d,
391395
points_features,

tests/bm_symeig3x3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66

77

88
from itertools import product
9-
from typing import Callable, Any
9+
from typing import Any, Callable
1010

1111
import torch
1212
from common_testing import get_random_cuda_device
1313
from fvcore.common.benchmark import benchmark
1414
from pytorch3d.common.workaround import symeig3x3
1515
from test_symeig3x3 import TestSymEig3x3
1616

17+
1718
torch.set_num_threads(1)
1819

1920
CUDA_DEVICE = get_random_cuda_device()

tests/test_iou_box3d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
1717
from pytorch3d.transforms.rotation_conversions import random_rotation
1818

19+
1920
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
2021
DATA_DIR = get_tests_dir() / "data"
2122
DEBUG = False

tests/test_points_to_volumes.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
import numpy as np
1313
import torch
1414
from common_testing import TestCaseMixin
15-
from pytorch3d.ops import add_pointclouds_to_volumes
15+
from pytorch3d.ops import (
16+
add_pointclouds_to_volumes,
17+
add_points_features_to_volume_densities_features,
18+
)
1619
from pytorch3d.ops.points_to_volumes import _points_to_volumes
1720
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
1821
from pytorch3d.structures.meshes import Meshes
@@ -373,6 +376,17 @@ def test_from_point_cloud(self, interp_mode="trilinear"):
373376
else:
374377
self.assertTrue(torch.isfinite(field.grad.data).all())
375378

379+
def test_defaulted_arguments(self):
380+
points = torch.rand(30, 1000, 3)
381+
features = torch.rand(30, 1000, 5)
382+
_, densities = add_points_features_to_volume_densities_features(
383+
points,
384+
features,
385+
torch.zeros(30, 1, 32, 32, 32),
386+
torch.zeros(30, 5, 32, 32, 32),
387+
)
388+
self.assertClose(torch.sum(densities), torch.tensor(30 * 1000.0), atol=0.1)
389+
376390
def _check_volume_slice_color_density(
377391
self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
378392
):

0 commit comments

Comments
 (0)