@@ -1935,7 +1935,14 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
1935
1935
@pytest .mark .parametrize (
1936
1936
"labels_getter" , ("default" , "labels" , lambda inputs : inputs ["labels" ], None , lambda inputs : None )
1937
1937
)
1938
- def test_sanitize_bounding_boxes (min_size , labels_getter ):
1938
+ @pytest .mark .parametrize ("sample_type" , (tuple , dict ))
1939
+ def test_sanitize_bounding_boxes (min_size , labels_getter , sample_type ):
1940
+
1941
+ if sample_type is tuple and not isinstance (labels_getter , str ):
1942
+ # The "lambda inputs: inputs["labels"]" labels_getter used in this test
1943
+ # doesn't work if the input is a tuple.
1944
+ return
1945
+
1939
1946
H , W = 256 , 128
1940
1947
1941
1948
boxes_and_validity = [
@@ -1970,35 +1977,56 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
1970
1977
)
1971
1978
1972
1979
masks = datapoints .Mask (torch .randint (0 , 2 , size = (boxes .shape [0 ], H , W )))
1973
-
1980
+ whatever = torch .rand (10 )
1981
+ input_img = torch .randint (0 , 256 , size = (1 , 3 , H , W ), dtype = torch .uint8 )
1974
1982
sample = {
1975
- "image" : torch . randint ( 0 , 256 , size = ( 1 , 3 , H , W ), dtype = torch . uint8 ) ,
1983
+ "image" : input_img ,
1976
1984
"labels" : labels ,
1977
1985
"boxes" : boxes ,
1978
- "whatever" : torch . rand ( 10 ) ,
1986
+ "whatever" : whatever ,
1979
1987
"None" : None ,
1980
1988
"masks" : masks ,
1981
1989
}
1982
1990
1991
+ if sample_type is tuple :
1992
+ img = sample .pop ("image" )
1993
+ sample = (img , sample )
1994
+
1983
1995
out = transforms .SanitizeBoundingBoxes (min_size = min_size , labels_getter = labels_getter )(sample )
1984
1996
1985
- assert out ["image" ] is sample ["image" ]
1986
- assert out ["whatever" ] is sample ["whatever" ]
1997
+ if sample_type is tuple :
1998
+ out_image = out [0 ]
1999
+ out_labels = out [1 ]["labels" ]
2000
+ out_boxes = out [1 ]["boxes" ]
2001
+ out_masks = out [1 ]["masks" ]
2002
+ out_whatever = out [1 ]["whatever" ]
2003
+ else :
2004
+ out_image = out ["image" ]
2005
+ out_labels = out ["labels" ]
2006
+ out_boxes = out ["boxes" ]
2007
+ out_masks = out ["masks" ]
2008
+ out_whatever = out ["whatever" ]
2009
+
2010
+ assert out_image is input_img
2011
+ assert out_whatever is whatever
1987
2012
1988
2013
if labels_getter is None or (callable (labels_getter ) and labels_getter ({"labels" : "blah" }) is None ):
1989
- assert out [ "labels" ] is sample [ " labels" ]
2014
+ assert out_labels is labels
1990
2015
else :
1991
- assert isinstance (out [ "labels" ] , torch .Tensor )
1992
- assert out [ "boxes" ] .shape [0 ] == out [ "labels" ] .shape [0 ] == out [ "masks" ] .shape [0 ]
2016
+ assert isinstance (out_labels , torch .Tensor )
2017
+ assert out_boxes .shape [0 ] == out_labels .shape [0 ] == out_masks .shape [0 ]
1993
2018
# This works because we conveniently set labels to arange(num_boxes)
1994
- assert out [ "labels" ] .tolist () == valid_indices
2019
+ assert out_labels .tolist () == valid_indices
1995
2020
1996
2021
1997
2022
@pytest .mark .parametrize ("key" , ("labels" , "LABELS" , "LaBeL" , "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT" ))
1998
- def test_sanitize_bounding_boxes_default_heuristic (key ):
2023
+ @pytest .mark .parametrize ("sample_type" , (tuple , dict ))
2024
+ def test_sanitize_bounding_boxes_default_heuristic (key , sample_type ):
1999
2025
labels = torch .arange (10 )
2000
- d = {key : labels }
2001
- assert transforms .SanitizeBoundingBoxes ._find_labels_default_heuristic (d ) is labels
2026
+ sample = {key : labels , "another_key" : "whatever" }
2027
+ if sample_type is tuple :
2028
+ sample = (None , sample , "whatever_again" )
2029
+ assert transforms .SanitizeBoundingBoxes ._find_labels_default_heuristic (sample ) is labels
2002
2030
2003
2031
if key .lower () != "labels" :
2004
2032
# If "labels" is in the dict (case-insensitive),
0 commit comments