Skip to content

Commit bfeb82e

Browse files
bottlerfacebook-github-bot
authored andcommitted
some pointcloud typing
Summary: Make clear that features_padded() etc can return None Reviewed By: patricklabatut Differential Revision: D31795088 fbshipit-source-id: 7b0bbb6f3b7ad7f7b6e6a727129537af1d1873af
1 parent 73a14d7 commit bfeb82e

File tree

1 file changed

+61
-45
lines changed

1 file changed

+61
-45
lines changed

pytorch3d/structures/pointclouds.py

+61-45
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from itertools import zip_longest
8-
from typing import Sequence, Union
8+
from typing import List, Optional, Sequence, Tuple, Union
99

1010
import numpy as np
1111
import torch
@@ -240,7 +240,9 @@ def __init__(self, points, normals=None, features=None) -> None:
240240
if features_C is not None:
241241
self._C = features_C
242242

243-
def _parse_auxiliary_input(self, aux_input):
243+
def _parse_auxiliary_input(
244+
self, aux_input
245+
) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor], Optional[int]]:
244246
"""
245247
Interpret the auxiliary inputs (normals, features) given to __init__.
246248
@@ -323,24 +325,26 @@ def __getitem__(self, index) -> "Pointclouds":
323325
Pointclouds object with selected clouds. The tensors are not cloned.
324326
"""
325327
normals, features = None, None
328+
normals_list = self.normals_list()
329+
features_list = self.features_list()
326330
if isinstance(index, int):
327331
points = [self.points_list()[index]]
328-
if self.normals_list() is not None:
329-
normals = [self.normals_list()[index]]
330-
if self.features_list() is not None:
331-
features = [self.features_list()[index]]
332+
if normals_list is not None:
333+
normals = [normals_list[index]]
334+
if features_list is not None:
335+
features = [features_list[index]]
332336
elif isinstance(index, slice):
333337
points = self.points_list()[index]
334-
if self.normals_list() is not None:
335-
normals = self.normals_list()[index]
336-
if self.features_list() is not None:
337-
features = self.features_list()[index]
338+
if normals_list is not None:
339+
normals = normals_list[index]
340+
if features_list is not None:
341+
features = features_list[index]
338342
elif isinstance(index, list):
339343
points = [self.points_list()[i] for i in index]
340-
if self.normals_list() is not None:
341-
normals = [self.normals_list()[i] for i in index]
342-
if self.features_list() is not None:
343-
features = [self.features_list()[i] for i in index]
344+
if normals_list is not None:
345+
normals = [normals_list[i] for i in index]
346+
if features_list is not None:
347+
features = [features_list[i] for i in index]
344348
elif isinstance(index, torch.Tensor):
345349
if index.dim() != 1 or index.dtype.is_floating_point:
346350
raise IndexError(index)
@@ -351,10 +355,10 @@ def __getitem__(self, index) -> "Pointclouds":
351355
index = index.squeeze(1) if index.numel() > 0 else index
352356
index = index.tolist()
353357
points = [self.points_list()[i] for i in index]
354-
if self.normals_list() is not None:
355-
normals = [self.normals_list()[i] for i in index]
356-
if self.features_list() is not None:
357-
features = [self.features_list()[i] for i in index]
358+
if normals_list is not None:
359+
normals = [normals_list[i] for i in index]
360+
if features_list is not None:
361+
features = [features_list[i] for i in index]
358362
else:
359363
raise IndexError(index)
360364

@@ -369,7 +373,7 @@ def isempty(self) -> bool:
369373
"""
370374
return self._N == 0 or self.valid.eq(False).all()
371375

372-
def points_list(self):
376+
def points_list(self) -> List[torch.Tensor]:
373377
"""
374378
Get the list representation of the points.
375379
@@ -388,9 +392,10 @@ def points_list(self):
388392
self._points_list = points_list
389393
return self._points_list
390394

391-
def normals_list(self):
395+
def normals_list(self) -> Optional[List[torch.Tensor]]:
392396
"""
393-
Get the list representation of the normals.
397+
Get the list representation of the normals,
398+
or None if there are no normals.
394399
395400
Returns:
396401
list of tensors of normals of shape (P_n, 3).
@@ -404,9 +409,10 @@ def normals_list(self):
404409
)
405410
return self._normals_list
406411

407-
def features_list(self):
412+
def features_list(self) -> Optional[List[torch.Tensor]]:
408413
"""
409-
Get the list representation of the features.
414+
Get the list representation of the features,
415+
or None if there are no features.
410416
411417
Returns:
412418
list of tensors of features of shape (P_n, C).
@@ -420,7 +426,7 @@ def features_list(self):
420426
)
421427
return self._features_list
422428

423-
def points_packed(self):
429+
def points_packed(self) -> torch.Tensor:
424430
"""
425431
Get the packed representation of the points.
426432
@@ -430,22 +436,24 @@ def points_packed(self):
430436
self._compute_packed()
431437
return self._points_packed
432438

433-
def normals_packed(self):
439+
def normals_packed(self) -> Optional[torch.Tensor]:
434440
"""
435441
Get the packed representation of the normals.
436442
437443
Returns:
438-
tensor of normals of shape (sum(P_n), 3).
444+
tensor of normals of shape (sum(P_n), 3),
445+
or None if there are no normals.
439446
"""
440447
self._compute_packed()
441448
return self._normals_packed
442449

443-
def features_packed(self):
450+
def features_packed(self) -> Optional[torch.Tensor]:
444451
"""
445452
Get the packed representation of the features.
446453
447454
Returns:
448-
tensor of features of shape (sum(P_n), C).
455+
tensor of features of shape (sum(P_n), C),
456+
or None if there are no features
449457
"""
450458
self._compute_packed()
451459
return self._features_packed
@@ -483,7 +491,7 @@ def num_points_per_cloud(self):
483491
"""
484492
return self._num_points_per_cloud
485493

486-
def points_padded(self):
494+
def points_padded(self) -> torch.Tensor:
487495
"""
488496
Get the padded representation of the points.
489497
@@ -493,19 +501,21 @@ def points_padded(self):
493501
self._compute_padded()
494502
return self._points_padded
495503

496-
def normals_padded(self):
504+
def normals_padded(self) -> Optional[torch.Tensor]:
497505
"""
498-
Get the padded representation of the normals.
506+
Get the padded representation of the normals,
507+
or None if there are no normals.
499508
500509
Returns:
501510
tensor of normals of shape (N, max(P_n), 3).
502511
"""
503512
self._compute_padded()
504513
return self._normals_padded
505514

506-
def features_padded(self):
515+
def features_padded(self) -> Optional[torch.Tensor]:
507516
"""
508-
Get the padded representation of the features.
517+
Get the padded representation of the features,
518+
or None if there are no features.
509519
510520
Returns:
511521
tensor of features of shape (N, max(P_n), 3).
@@ -562,16 +572,18 @@ def _compute_padded(self, refresh: bool = False):
562572
pad_value=0.0,
563573
equisized=self.equisized,
564574
)
565-
if self.normals_list() is not None:
575+
normals_list = self.normals_list()
576+
if normals_list is not None:
566577
self._normals_padded = struct_utils.list_to_padded(
567-
self.normals_list(),
578+
normals_list,
568579
(self._P, 3),
569580
pad_value=0.0,
570581
equisized=self.equisized,
571582
)
572-
if self.features_list() is not None:
583+
features_list = self.features_list()
584+
if features_list is not None:
573585
self._features_padded = struct_utils.list_to_padded(
574-
self.features_list(),
586+
features_list,
575587
(self._P, self._C),
576588
pad_value=0.0,
577589
equisized=self.equisized,
@@ -772,10 +784,12 @@ def get_cloud(self, index: int):
772784
)
773785
points = self.points_list()[index]
774786
normals, features = None, None
775-
if self.normals_list() is not None:
776-
normals = self.normals_list()[index]
777-
if self.features_list() is not None:
778-
features = self.features_list()[index]
787+
normals_list = self.normals_list()
788+
if normals_list is not None:
789+
normals = normals_list[index]
790+
features_list = self.features_list()
791+
if features_list is not None:
792+
features = features_list[index]
779793
return points, normals, features
780794

781795
# TODO(nikhilar) Move function to a utils file.
@@ -1022,13 +1036,15 @@ def extend(self, N: int):
10221036
new_points_list, new_normals_list, new_features_list = [], None, None
10231037
for points in self.points_list():
10241038
new_points_list.extend(points.clone() for _ in range(N))
1025-
if self.normals_list() is not None:
1039+
normals_list = self.normals_list()
1040+
if normals_list is not None:
10261041
new_normals_list = []
1027-
for normals in self.normals_list():
1042+
for normals in normals_list:
10281043
new_normals_list.extend(normals.clone() for _ in range(N))
1029-
if self.features_list() is not None:
1044+
features_list = self.features_list()
1045+
if features_list is not None:
10301046
new_features_list = []
1031-
for features in self.features_list():
1047+
for features in features_list:
10321048
new_features_list.extend(features.clone() for _ in range(N))
10331049
return self.__class__(
10341050
points=new_points_list, normals=new_normals_list, features=new_features_list

0 commit comments

Comments
 (0)