5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
from itertools import zip_longest
8
- from typing import Sequence , Union
8
+ from typing import List , Optional , Sequence , Tuple , Union
9
9
10
10
import numpy as np
11
11
import torch
@@ -240,7 +240,9 @@ def __init__(self, points, normals=None, features=None) -> None:
240
240
if features_C is not None :
241
241
self ._C = features_C
242
242
243
- def _parse_auxiliary_input (self , aux_input ):
243
+ def _parse_auxiliary_input (
244
+ self , aux_input
245
+ ) -> Tuple [Optional [List [torch .Tensor ]], Optional [torch .Tensor ], Optional [int ]]:
244
246
"""
245
247
Interpret the auxiliary inputs (normals, features) given to __init__.
246
248
@@ -323,24 +325,26 @@ def __getitem__(self, index) -> "Pointclouds":
323
325
Pointclouds object with selected clouds. The tensors are not cloned.
324
326
"""
325
327
normals , features = None , None
328
+ normals_list = self .normals_list ()
329
+ features_list = self .features_list ()
326
330
if isinstance (index , int ):
327
331
points = [self .points_list ()[index ]]
328
- if self . normals_list () is not None :
329
- normals = [self . normals_list () [index ]]
330
- if self . features_list () is not None :
331
- features = [self . features_list () [index ]]
332
+ if normals_list is not None :
333
+ normals = [normals_list [index ]]
334
+ if features_list is not None :
335
+ features = [features_list [index ]]
332
336
elif isinstance (index , slice ):
333
337
points = self .points_list ()[index ]
334
- if self . normals_list () is not None :
335
- normals = self . normals_list () [index ]
336
- if self . features_list () is not None :
337
- features = self . features_list () [index ]
338
+ if normals_list is not None :
339
+ normals = normals_list [index ]
340
+ if features_list is not None :
341
+ features = features_list [index ]
338
342
elif isinstance (index , list ):
339
343
points = [self .points_list ()[i ] for i in index ]
340
- if self . normals_list () is not None :
341
- normals = [self . normals_list () [i ] for i in index ]
342
- if self . features_list () is not None :
343
- features = [self . features_list () [i ] for i in index ]
344
+ if normals_list is not None :
345
+ normals = [normals_list [i ] for i in index ]
346
+ if features_list is not None :
347
+ features = [features_list [i ] for i in index ]
344
348
elif isinstance (index , torch .Tensor ):
345
349
if index .dim () != 1 or index .dtype .is_floating_point :
346
350
raise IndexError (index )
@@ -351,10 +355,10 @@ def __getitem__(self, index) -> "Pointclouds":
351
355
index = index .squeeze (1 ) if index .numel () > 0 else index
352
356
index = index .tolist ()
353
357
points = [self .points_list ()[i ] for i in index ]
354
- if self . normals_list () is not None :
355
- normals = [self . normals_list () [i ] for i in index ]
356
- if self . features_list () is not None :
357
- features = [self . features_list () [i ] for i in index ]
358
+ if normals_list is not None :
359
+ normals = [normals_list [i ] for i in index ]
360
+ if features_list is not None :
361
+ features = [features_list [i ] for i in index ]
358
362
else :
359
363
raise IndexError (index )
360
364
@@ -369,7 +373,7 @@ def isempty(self) -> bool:
369
373
"""
370
374
return self ._N == 0 or self .valid .eq (False ).all ()
371
375
372
- def points_list (self ):
376
+ def points_list (self ) -> List [ torch . Tensor ] :
373
377
"""
374
378
Get the list representation of the points.
375
379
@@ -388,9 +392,10 @@ def points_list(self):
388
392
self ._points_list = points_list
389
393
return self ._points_list
390
394
391
- def normals_list (self ):
395
+ def normals_list (self ) -> Optional [ List [ torch . Tensor ]] :
392
396
"""
393
- Get the list representation of the normals.
397
+ Get the list representation of the normals,
398
+ or None if there are no normals.
394
399
395
400
Returns:
396
401
list of tensors of normals of shape (P_n, 3).
@@ -404,9 +409,10 @@ def normals_list(self):
404
409
)
405
410
return self ._normals_list
406
411
407
- def features_list (self ):
412
+ def features_list (self ) -> Optional [ List [ torch . Tensor ]] :
408
413
"""
409
- Get the list representation of the features.
414
+ Get the list representation of the features,
415
+ or None if there are no features.
410
416
411
417
Returns:
412
418
list of tensors of features of shape (P_n, C).
@@ -420,7 +426,7 @@ def features_list(self):
420
426
)
421
427
return self ._features_list
422
428
423
- def points_packed (self ):
429
+ def points_packed (self ) -> torch . Tensor :
424
430
"""
425
431
Get the packed representation of the points.
426
432
@@ -430,22 +436,24 @@ def points_packed(self):
430
436
self ._compute_packed ()
431
437
return self ._points_packed
432
438
433
- def normals_packed (self ):
439
+ def normals_packed (self ) -> Optional [ torch . Tensor ] :
434
440
"""
435
441
Get the packed representation of the normals.
436
442
437
443
Returns:
438
- tensor of normals of shape (sum(P_n), 3).
444
+ tensor of normals of shape (sum(P_n), 3),
445
+ or None if there are no normals.
439
446
"""
440
447
self ._compute_packed ()
441
448
return self ._normals_packed
442
449
443
- def features_packed (self ):
450
+ def features_packed (self ) -> Optional [ torch . Tensor ] :
444
451
"""
445
452
Get the packed representation of the features.
446
453
447
454
Returns:
448
- tensor of features of shape (sum(P_n), C).
455
+ tensor of features of shape (sum(P_n), C),
456
+ or None if there are no features
449
457
"""
450
458
self ._compute_packed ()
451
459
return self ._features_packed
@@ -483,7 +491,7 @@ def num_points_per_cloud(self):
483
491
"""
484
492
return self ._num_points_per_cloud
485
493
486
- def points_padded (self ):
494
+ def points_padded (self ) -> torch . Tensor :
487
495
"""
488
496
Get the padded representation of the points.
489
497
@@ -493,19 +501,21 @@ def points_padded(self):
493
501
self ._compute_padded ()
494
502
return self ._points_padded
495
503
496
- def normals_padded (self ):
504
+ def normals_padded (self ) -> Optional [ torch . Tensor ] :
497
505
"""
498
- Get the padded representation of the normals.
506
+ Get the padded representation of the normals,
507
+ or None if there are no normals.
499
508
500
509
Returns:
501
510
tensor of normals of shape (N, max(P_n), 3).
502
511
"""
503
512
self ._compute_padded ()
504
513
return self ._normals_padded
505
514
506
- def features_padded (self ):
515
+ def features_padded (self ) -> Optional [ torch . Tensor ] :
507
516
"""
508
- Get the padded representation of the features.
517
+ Get the padded representation of the features,
518
+ or None if there are no features.
509
519
510
520
Returns:
511
521
tensor of features of shape (N, max(P_n), 3).
@@ -562,16 +572,18 @@ def _compute_padded(self, refresh: bool = False):
562
572
pad_value = 0.0 ,
563
573
equisized = self .equisized ,
564
574
)
565
- if self .normals_list () is not None :
575
+ normals_list = self .normals_list ()
576
+ if normals_list is not None :
566
577
self ._normals_padded = struct_utils .list_to_padded (
567
- self . normals_list () ,
578
+ normals_list ,
568
579
(self ._P , 3 ),
569
580
pad_value = 0.0 ,
570
581
equisized = self .equisized ,
571
582
)
572
- if self .features_list () is not None :
583
+ features_list = self .features_list ()
584
+ if features_list is not None :
573
585
self ._features_padded = struct_utils .list_to_padded (
574
- self . features_list () ,
586
+ features_list ,
575
587
(self ._P , self ._C ),
576
588
pad_value = 0.0 ,
577
589
equisized = self .equisized ,
@@ -772,10 +784,12 @@ def get_cloud(self, index: int):
772
784
)
773
785
points = self .points_list ()[index ]
774
786
normals , features = None , None
775
- if self .normals_list () is not None :
776
- normals = self .normals_list ()[index ]
777
- if self .features_list () is not None :
778
- features = self .features_list ()[index ]
787
+ normals_list = self .normals_list ()
788
+ if normals_list is not None :
789
+ normals = normals_list [index ]
790
+ features_list = self .features_list ()
791
+ if features_list is not None :
792
+ features = features_list [index ]
779
793
return points , normals , features
780
794
781
795
# TODO(nikhilar) Move function to a utils file.
@@ -1022,13 +1036,15 @@ def extend(self, N: int):
1022
1036
new_points_list , new_normals_list , new_features_list = [], None , None
1023
1037
for points in self .points_list ():
1024
1038
new_points_list .extend (points .clone () for _ in range (N ))
1025
- if self .normals_list () is not None :
1039
+ normals_list = self .normals_list ()
1040
+ if normals_list is not None :
1026
1041
new_normals_list = []
1027
- for normals in self . normals_list () :
1042
+ for normals in normals_list :
1028
1043
new_normals_list .extend (normals .clone () for _ in range (N ))
1029
- if self .features_list () is not None :
1044
+ features_list = self .features_list ()
1045
+ if features_list is not None :
1030
1046
new_features_list = []
1031
- for features in self . features_list () :
1047
+ for features in features_list :
1032
1048
new_features_list .extend (features .clone () for _ in range (N ))
1033
1049
return self .__class__ (
1034
1050
points = new_points_list , normals = new_normals_list , features = new_features_list
0 commit comments