Skip to content

Commit 4281df1

Browse files
bottlerfacebook-github-bot
authored andcommitted
subsample pointclouds
Summary: New function to randomly subsample Pointclouds to a maximum size. Reviewed By: nikhilaravi Differential Revision: D30936533 fbshipit-source-id: 789eb5004b6a233034ec1c500f20f2d507a303ff
1 parent ee2b2fe commit 4281df1

File tree

3 files changed

+92
-21
lines changed

3 files changed

+92
-21
lines changed

pytorch3d/structures/pointclouds.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from itertools import zip_longest
8+
from typing import Sequence, Union
9+
10+
import numpy as np
711
import torch
812

913
from ..common.types import Device, make_device
@@ -841,6 +845,54 @@ def offset(self, offsets_packed):
841845
new_clouds = self.clone()
842846
return new_clouds.offset_(offsets_packed)
843847

848+
def subsample(self, max_points: Union[int, Sequence[int]]) -> "Pointclouds":
849+
"""
850+
Subsample each cloud so that it has at most max_points points.
851+
852+
Args:
853+
max_points: maximum number of points in each cloud.
854+
855+
Returns:
856+
new Pointclouds object, or self if nothing to be done.
857+
"""
858+
if isinstance(max_points, int):
859+
max_points = [max_points] * len(self)
860+
elif len(max_points) != len(self):
861+
raise ValueError("wrong number of max_points supplied")
862+
if all(
863+
int(n_points) <= int(max_)
864+
for n_points, max_ in zip(self.num_points_per_cloud(), max_points)
865+
):
866+
return self
867+
868+
points_list = []
869+
features_list = []
870+
normals_list = []
871+
for max_, n_points, points, features, normals in zip_longest(
872+
map(int, max_points),
873+
map(int, self.num_points_per_cloud()),
874+
self.points_list(),
875+
self.features_list() or (),
876+
self.normals_list() or (),
877+
):
878+
if n_points > max_:
879+
keep_np = np.random.choice(n_points, max_, replace=False)
880+
keep = torch.tensor(keep_np).to(points.device)
881+
points = points[keep]
882+
if features is not None:
883+
features = features[keep]
884+
if normals is not None:
885+
normals = normals[keep]
886+
points_list.append(points)
887+
features_list.append(features)
888+
normals_list.append(normals)
889+
890+
return Pointclouds(
891+
points=points_list,
892+
normals=self.normals_list() and normals_list,
893+
features=self.features_list() and features_list,
894+
)
895+
844896
def scale_(self, scale):
845897
"""
846898
Multiply the coordinates of this object by a scalar value.

pytorch3d/vis/plotly_vis.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import warnings
88
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
99

10-
import numpy as np
1110
import plotly.graph_objects as go
1211
import torch
1312
from plotly.subplots import make_subplots
@@ -644,31 +643,12 @@ def _add_pointcloud_trace(
644643
max_points_per_pointcloud: the number of points to render, which are randomly sampled.
645644
marker_size: the size of the rendered points
646645
"""
647-
pointclouds = pointclouds.detach().cpu()
646+
pointclouds = pointclouds.detach().cpu().subsample(max_points_per_pointcloud)
648647
verts = pointclouds.points_packed()
649648
features = pointclouds.features_packed()
650649

651-
indices = None
652-
if pointclouds.num_points_per_cloud().max() > max_points_per_pointcloud:
653-
start_index = 0
654-
index_list = []
655-
for num_points in pointclouds.num_points_per_cloud():
656-
if num_points > max_points_per_pointcloud:
657-
indices_cloud = np.random.choice(
658-
num_points, max_points_per_pointcloud, replace=False
659-
)
660-
index_list.append(start_index + indices_cloud)
661-
else:
662-
index_list.append(start_index + np.arange(num_points))
663-
start_index += num_points
664-
indices = np.concatenate(index_list)
665-
verts = verts[indices]
666-
667650
color = None
668651
if features is not None:
669-
if indices is not None:
670-
# Only select features if we selected vertices above
671-
features = features[indices]
672652
if features.shape[1] == 4: # rgba
673653
template = "rgb(%d, %d, %d, %f)"
674654
rgb = (features[:, :3].clamp(0.0, 1.0) * 255).int()

tests/test_pointclouds.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,45 @@ def test_estimate_normals(self):
10571057
clouds.normals_packed(), torch.cat(normals_est_list, dim=0)
10581058
)
10591059

1060+
def test_subsample(self):
1061+
lengths = [4, 5, 13, 3]
1062+
points = [torch.rand(length, 3) for length in lengths]
1063+
features = [torch.rand(length, 5) for length in lengths]
1064+
normals = [torch.rand(length, 3) for length in lengths]
1065+
1066+
pcl1 = Pointclouds(points=points).cuda()
1067+
self.assertIs(pcl1, pcl1.subsample(13))
1068+
self.assertIs(pcl1, pcl1.subsample([6, 13, 13, 13]))
1069+
1070+
lengths_max_4 = torch.tensor([4, 4, 4, 3]).cuda()
1071+
for with_normals, with_features in itertools.product([True, False], repeat=2):
1072+
with self.subTest(f"{with_normals} {with_features}"):
1073+
pcl = Pointclouds(
1074+
points=points,
1075+
normals=normals if with_normals else None,
1076+
features=features if with_features else None,
1077+
)
1078+
pcl_copy = pcl.subsample(max_points=4)
1079+
for length, points_ in zip(lengths_max_4, pcl_copy.points_list()):
1080+
self.assertEqual(points_.shape, (length, 3))
1081+
if with_normals:
1082+
for length, normals_ in zip(lengths_max_4, pcl_copy.normals_list()):
1083+
self.assertEqual(normals_.shape, (length, 3))
1084+
else:
1085+
self.assertIsNone(pcl_copy.normals_list())
1086+
if with_features:
1087+
for length, features_ in zip(
1088+
lengths_max_4, pcl_copy.features_list()
1089+
):
1090+
self.assertEqual(features_.shape, (length, 5))
1091+
else:
1092+
self.assertIsNone(pcl_copy.features_list())
1093+
1094+
pcl2 = Pointclouds(points=points)
1095+
pcl_copy2 = pcl2.subsample(lengths_max_4)
1096+
for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()):
1097+
self.assertEqual(points_.shape, (length, 3))
1098+
10601099
@staticmethod
10611100
def compute_packed_with_init(
10621101
num_clouds: int = 10, max_p: int = 100, features: int = 300

0 commit comments

Comments
 (0)