@@ -884,7 +884,8 @@ def test_camera_class_init(self):
884
884
self .assertTrue (new_cam .device == device )
885
885
886
886
def test_getitem (self ):
887
- R_matrix = torch .randn ((6 , 3 , 3 ))
887
+ N_CAMERAS = 6
888
+ R_matrix = torch .randn ((N_CAMERAS , 3 , 3 ))
888
889
cam = FoVPerspectiveCameras (znear = 10.0 , zfar = 100.0 , R = R_matrix )
889
890
890
891
# Check get item returns an instance of the same class
@@ -908,22 +909,39 @@ def test_getitem(self):
908
909
self .assertClose (c012 .R , R_matrix [0 :3 , ...])
909
910
910
911
# Check torch.LongTensor index
911
- index = torch .tensor ([1 , 3 , 5 ], dtype = torch .int64 )
912
+ SLICE = [1 , 3 , 5 ]
913
+ index = torch .tensor (SLICE , dtype = torch .int64 )
912
914
c135 = cam [index ]
913
915
self .assertEqual (len (c135 ), 3 )
914
916
self .assertClose (c135 .zfar , torch .tensor ([100.0 ] * 3 ))
915
917
self .assertClose (c135 .znear , torch .tensor ([10.0 ] * 3 ))
916
- self .assertClose (c135 .R , R_matrix [[1 , 3 , 5 ], ...])
918
+ self .assertClose (c135 .R , R_matrix [SLICE , ...])
919
+
920
+ # Check torch.BoolTensor index
921
+ bool_slice = [i in SLICE for i in range (N_CAMERAS )]
922
+ index = torch .tensor (bool_slice , dtype = torch .bool )
923
+ c135 = cam [index ]
924
+ self .assertEqual (len (c135 ), 3 )
925
+ self .assertClose (c135 .zfar , torch .tensor ([100.0 ] * 3 ))
926
+ self .assertClose (c135 .znear , torch .tensor ([10.0 ] * 3 ))
927
+ self .assertClose (c135 .R , R_matrix [SLICE , ...])
917
928
918
929
# Check errors with get item
919
930
with self .assertRaisesRegex (ValueError , "out of bounds" ):
920
- cam [6 ]
931
+ cam [N_CAMERAS ]
932
+
933
+ with self .assertRaisesRegex (ValueError , "does not match cameras" ):
934
+ index = torch .tensor ([1 , 0 , 1 ], dtype = torch .bool )
935
+ cam [index ]
921
936
922
937
with self .assertRaisesRegex (ValueError , "Invalid index type" ):
923
938
cam [slice (0 , 1 )]
924
939
925
940
with self .assertRaisesRegex (ValueError , "Invalid index type" ):
926
- index = torch .tensor ([1 , 3 , 5 ], dtype = torch .float32 )
941
+ cam [[True , False ]]
942
+
943
+ with self .assertRaisesRegex (ValueError , "Invalid index type" ):
944
+ index = torch .tensor (SLICE , dtype = torch .float32 )
927
945
cam [index ]
928
946
929
947
def test_get_full_transform (self ):
0 commit comments