Skip to content

Commit 262c1bf

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Join points as batch
Summary: Function to join a list of pointclouds as a batch similar to the corresponding function for Meshes. Reviewed By: bottler Differential Revision: D33145906 fbshipit-source-id: 160639ebb5065e4fae1a1aa43117172719f3871b
1 parent eb2bbf8 commit 262c1bf

File tree

2 files changed

+102
-1
lines changed

2 files changed

+102
-1
lines changed

pytorch3d/structures/pointclouds.py

+37
Original file line numberDiff line numberDiff line change
@@ -1178,3 +1178,40 @@ def inside_box(self, box):
11781178

11791179
coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
11801180
return coord_inside.all(dim=-1)
1181+
1182+
1183+
def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]):
1184+
"""
1185+
Merge a list of Pointclouds objects into a single batched Pointclouds
1186+
object. All pointclouds must be on the same device.
1187+
1188+
Args:
1189+
batch: List of Pointclouds objects each with batch dim [b1, b2, ..., bN]
1190+
Returns:
1191+
pointcloud: Poinclouds object with all input pointclouds collated into
1192+
a single object with batch dim = sum(b1, b2, ..., bN)
1193+
"""
1194+
if isinstance(pointclouds, Pointclouds) or not isinstance(pointclouds, Sequence):
1195+
raise ValueError("Wrong first argument to join_points_as_batch.")
1196+
1197+
device = pointclouds[0].device
1198+
if not all(p.device == device for p in pointclouds):
1199+
raise ValueError("Pointclouds must all be on the same device")
1200+
1201+
kwargs = {}
1202+
for field in ("points", "normals", "features"):
1203+
field_list = [getattr(p, field + "_list")() for p in pointclouds]
1204+
if None in field_list:
1205+
if field == "points":
1206+
raise ValueError("Pointclouds cannot have their points set to None!")
1207+
if not all(f is None for f in field_list):
1208+
raise ValueError(
1209+
f"Pointclouds in the batch have some fields '{field}'"
1210+
+ " defined and some set to None."
1211+
)
1212+
field_list = None
1213+
else:
1214+
field_list = [p for points in field_list for p in points]
1215+
kwargs[field] = field_list
1216+
1217+
return Pointclouds(**kwargs)

tests/test_pointclouds.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
from common_testing import TestCaseMixin
1414
from pytorch3d.structures import utils as struct_utils
15-
from pytorch3d.structures.pointclouds import Pointclouds
15+
from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
1616

1717

1818
class TestPointclouds(TestCaseMixin, unittest.TestCase):
@@ -1098,6 +1098,70 @@ def test_subsample(self):
10981098
for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()):
10991099
self.assertEqual(points_.shape, (length, 3))
11001100

1101+
def test_join_pointclouds_as_batch(self):
1102+
"""
1103+
Test join_pointclouds_as_batch
1104+
"""
1105+
1106+
def check_item(x, y):
1107+
self.assertEqual(x is None, y is None)
1108+
if x is not None:
1109+
self.assertClose(torch.cat([x, x, x]), y)
1110+
1111+
def check_triple(points, points3):
1112+
"""
1113+
Verify that points3 is three copies of points.
1114+
"""
1115+
check_item(points.points_padded(), points3.points_padded())
1116+
check_item(points.normals_padded(), points3.normals_padded())
1117+
check_item(points.features_padded(), points3.features_padded())
1118+
1119+
lengths = [4, 5, 13, 3]
1120+
points = [torch.rand(length, 3) for length in lengths]
1121+
features = [torch.rand(length, 5) for length in lengths]
1122+
normals = [torch.rand(length, 3) for length in lengths]
1123+
1124+
# Test with normals and features present
1125+
pcl = Pointclouds(points=points, features=features, normals=normals)
1126+
pcl3 = join_pointclouds_as_batch([pcl] * 3)
1127+
check_triple(pcl, pcl3)
1128+
1129+
# Test with normals and features present for tensor backed pointclouds
1130+
N, P, D = 5, 30, 4
1131+
pcl = Pointclouds(
1132+
points=torch.rand(N, P, 3),
1133+
features=torch.rand(N, P, D),
1134+
normals=torch.rand(N, P, 3),
1135+
)
1136+
pcl3 = join_pointclouds_as_batch([pcl] * 3)
1137+
check_triple(pcl, pcl3)
1138+
1139+
# Test without normals
1140+
pcl_nonormals = Pointclouds(points=points, features=features)
1141+
pcl3 = join_pointclouds_as_batch([pcl_nonormals] * 3)
1142+
check_triple(pcl_nonormals, pcl3)
1143+
1144+
# Test without features
1145+
pcl_nofeats = Pointclouds(points=points, normals=normals)
1146+
pcl3 = join_pointclouds_as_batch([pcl_nofeats] * 3)
1147+
check_triple(pcl_nofeats, pcl3)
1148+
1149+
# Check error raised if all pointclouds in the batch
1150+
# are not consistent in including normals/features
1151+
with self.assertRaisesRegex(ValueError, "some set to None"):
1152+
join_pointclouds_as_batch([pcl, pcl_nonormals, pcl_nonormals])
1153+
with self.assertRaisesRegex(ValueError, "some set to None"):
1154+
join_pointclouds_as_batch([pcl, pcl_nofeats, pcl_nofeats])
1155+
1156+
# Check error if first input is a single pointclouds object
1157+
# instead of a list
1158+
with self.assertRaisesRegex(ValueError, "Wrong first argument"):
1159+
join_pointclouds_as_batch(pcl)
1160+
1161+
# Check error if all pointclouds are not on the same device
1162+
with self.assertRaisesRegex(ValueError, "same device"):
1163+
join_pointclouds_as_batch([pcl, pcl.to("cuda:0")])
1164+
11011165
@staticmethod
11021166
def compute_packed_with_init(
11031167
num_clouds: int = 10, max_p: int = 100, features: int = 300

0 commit comments

Comments
 (0)