|
12 | 12 | import torch
|
13 | 13 | from common_testing import TestCaseMixin
|
14 | 14 | from pytorch3d.structures import utils as struct_utils
|
15 |
| -from pytorch3d.structures.pointclouds import Pointclouds |
| 15 | +from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch |
16 | 16 |
|
17 | 17 |
|
18 | 18 | class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
@@ -1098,6 +1098,70 @@ def test_subsample(self):
|
1098 | 1098 | for length, points_ in zip(lengths_max_4, pcl_copy2.points_list()):
|
1099 | 1099 | self.assertEqual(points_.shape, (length, 3))
|
1100 | 1100 |
|
| 1101 | + def test_join_pointclouds_as_batch(self): |
| 1102 | + """ |
| 1103 | + Test join_pointclouds_as_batch |
| 1104 | + """ |
| 1105 | + |
| 1106 | + def check_item(x, y): |
| 1107 | + self.assertEqual(x is None, y is None) |
| 1108 | + if x is not None: |
| 1109 | + self.assertClose(torch.cat([x, x, x]), y) |
| 1110 | + |
| 1111 | + def check_triple(points, points3): |
| 1112 | + """ |
| 1113 | + Verify that points3 is three copies of points. |
| 1114 | + """ |
| 1115 | + check_item(points.points_padded(), points3.points_padded()) |
| 1116 | + check_item(points.normals_padded(), points3.normals_padded()) |
| 1117 | + check_item(points.features_padded(), points3.features_padded()) |
| 1118 | + |
| 1119 | + lengths = [4, 5, 13, 3] |
| 1120 | + points = [torch.rand(length, 3) for length in lengths] |
| 1121 | + features = [torch.rand(length, 5) for length in lengths] |
| 1122 | + normals = [torch.rand(length, 3) for length in lengths] |
| 1123 | + |
| 1124 | + # Test with normals and features present |
| 1125 | + pcl = Pointclouds(points=points, features=features, normals=normals) |
| 1126 | + pcl3 = join_pointclouds_as_batch([pcl] * 3) |
| 1127 | + check_triple(pcl, pcl3) |
| 1128 | + |
| 1129 | + # Test with normals and features present for tensor backed pointclouds |
| 1130 | + N, P, D = 5, 30, 4 |
| 1131 | + pcl = Pointclouds( |
| 1132 | + points=torch.rand(N, P, 3), |
| 1133 | + features=torch.rand(N, P, D), |
| 1134 | + normals=torch.rand(N, P, 3), |
| 1135 | + ) |
| 1136 | + pcl3 = join_pointclouds_as_batch([pcl] * 3) |
| 1137 | + check_triple(pcl, pcl3) |
| 1138 | + |
| 1139 | + # Test without normals |
| 1140 | + pcl_nonormals = Pointclouds(points=points, features=features) |
| 1141 | + pcl3 = join_pointclouds_as_batch([pcl_nonormals] * 3) |
| 1142 | + check_triple(pcl_nonormals, pcl3) |
| 1143 | + |
| 1144 | + # Test without features |
| 1145 | + pcl_nofeats = Pointclouds(points=points, normals=normals) |
| 1146 | + pcl3 = join_pointclouds_as_batch([pcl_nofeats] * 3) |
| 1147 | + check_triple(pcl_nofeats, pcl3) |
| 1148 | + |
| 1149 | + # Check error raised if all pointclouds in the batch |
| 1150 | + # are not consistent in including normals/features |
| 1151 | + with self.assertRaisesRegex(ValueError, "some set to None"): |
| 1152 | + join_pointclouds_as_batch([pcl, pcl_nonormals, pcl_nonormals]) |
| 1153 | + with self.assertRaisesRegex(ValueError, "some set to None"): |
| 1154 | + join_pointclouds_as_batch([pcl, pcl_nofeats, pcl_nofeats]) |
| 1155 | + |
| 1156 | + # Check error if first input is a single pointclouds object |
| 1157 | + # instead of a list |
| 1158 | + with self.assertRaisesRegex(ValueError, "Wrong first argument"): |
| 1159 | + join_pointclouds_as_batch(pcl) |
| 1160 | + |
| 1161 | + # Check error if all pointclouds are not on the same device |
| 1162 | + with self.assertRaisesRegex(ValueError, "same device"): |
| 1163 | + join_pointclouds_as_batch([pcl, pcl.to("cuda:0")]) |
| 1164 | + |
1101 | 1165 | @staticmethod
|
1102 | 1166 | def compute_packed_with_init(
|
1103 | 1167 | num_clouds: int = 10, max_p: int = 100, features: int = 300
|
|
0 commit comments