@@ -412,7 +412,7 @@ def load(self, device="cpu"):
         )


-def _parse_spatial_size(size, *, name="size"):
+def _parse_canvas_size(size, *, name="size"):
     if size == "random":
         raise ValueError("This should never happen")
     elif isinstance(size, int) and size > 0:
@@ -467,12 +467,13 @@ def load(self, device):

 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
     memory_format: torch.memory_format = torch.contiguous_format
+    canvas_size: Tuple[int, int] = dataclasses.field(init=False)

     def __post_init__(self):
-        self.spatial_size = self.shape[-2:]
+        self.spatial_size = self.canvas_size = self.shape[-2:]
         self.num_channels = self.shape[-3]

     def load(self, device):
@@ -538,7 +539,7 @@ def make_image_loader(
 ):
     if not constant_alpha:
         raise ValueError("This should never happen")
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
     num_channels = get_num_channels(color_space)

     def fn(shape, dtype, device, memory_format):
@@ -578,7 +579,7 @@ def make_image_loaders(
 def make_image_loader_for_interpolation(
     size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
 ):
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
     num_channels = get_num_channels(color_space)

     def fn(shape, dtype, device, memory_format):
@@ -623,43 +624,20 @@ def make_image_loaders_for_interpolation(
 class BoundingBoxesLoader(TensorLoader):
     format: datapoints.BoundingBoxFormat
     spatial_size: Tuple[int, int]
+    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
+
+    def __post_init__(self):
+        self.canvas_size = self.spatial_size


 def make_bounding_box(
-    size=None,
+    canvas_size=DEFAULT_SIZE,
     *,
     format=datapoints.BoundingBoxFormat.XYXY,
-    spatial_size=None,
     batch_dims=(),
     dtype=None,
     device="cpu",
 ):
-    """
-    size: Size of the actual bounding box, i.e.
-        - (box[3] - box[1], box[2] - box[0]) for XYXY
-        - (H, W) for XYWH and CXCYWH
-    spatial_size: Size of the reference object, e.g. an image. Corresponds to the .spatial_size attribute on
-        returned datapoints.BoundingBoxes
-
-    To generate a valid joint sample, you need to set spatial_size here to the same value as size on the other maker
-    functions, e.g.
-
-    .. code::
-
-        image = make_image=(size=size)
-        bounding_boxes = make_bounding_box(spatial_size=size)
-        assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
-
-    For convenience, if both size and spatial_size are omitted, spatial_size defaults to the same value as size for all
-    other maker functions, e.g.
-
-    .. code::
-
-        image = make_image=()
-        bounding_boxes = make_bounding_box()
-        assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
-    """
-
     def sample_position(values, max_value):
         # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
         # However, if we have batch_dims, we need tensors as limits.
@@ -668,28 +646,16 @@ def sample_position(values, max_value):
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]

-    if spatial_size is None:
-        if size is None:
-            spatial_size = DEFAULT_SIZE
-        else:
-            height, width = size
-            height_margin, width_margin = torch.randint(10, (2,)).tolist()
-            spatial_size = (height + height_margin, width + width_margin)
-
     dtype = dtype or torch.float32

     if any(dim == 0 for dim in batch_dims):
         return datapoints.BoundingBoxes(
-            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
+            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
         )

-    if size is None:
-        h, w = [torch.randint(1, s, batch_dims) for s in spatial_size]
-    else:
-        h, w = [torch.full(batch_dims, s, dtype=torch.int) for s in size]
-
-    y = sample_position(h, spatial_size[0])
-    x = sample_position(w, spatial_size[1])
+    h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
+    y = sample_position(h, canvas_size[0])
+    x = sample_position(w, canvas_size[1])

     if format is datapoints.BoundingBoxFormat.XYWH:
         parts = (x, y, w, h)
@@ -706,37 +672,37 @@ def sample_position(values, max_value):
         raise ValueError(f"Format {format} is not supported")

     return datapoints.BoundingBoxes(
-        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
+        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
     )


-def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]

-    spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")
+    canvas_size = _parse_canvas_size(canvas_size, name="canvas_size")

     def fn(shape, dtype, device):
         *batch_dims, num_coordinates = shape
         if num_coordinates != 4:
             raise pytest.UsageError()

         return make_bounding_box(
-            format=format, spatial_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
+            format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device
         )

-    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size)


 def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(datapoints.BoundingBoxFormat),
-    spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
+    canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
-        yield make_bounding_box_loader(**params, spatial_size=spatial_size)
+        yield make_bounding_box_loader(**params, canvas_size=canvas_size)


 make_bounding_boxes = from_loaders(make_bounding_box_loaders)
@@ -761,7 +727,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp

 def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
     # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)

     def fn(shape, dtype, device):
         *batch_dims, num_objects, height, width = shape
@@ -802,15 +768,15 @@ def make_segmentation_mask_loader(
     size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
 ):
     # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
-    spatial_size = _parse_spatial_size(size)
+    canvas_size = _parse_canvas_size(size)

     def fn(shape, dtype, device):
         *batch_dims, height, width = shape
         return make_segmentation_mask(
             (height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
         )

-    return MaskLoader(fn, shape=(*extra_dims, *spatial_size), dtype=dtype)
+    return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype)


 def make_segmentation_mask_loaders(
@@ -860,7 +826,7 @@ def make_video_loader(
     extra_dims=(),
     dtype=torch.uint8,
 ):
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)

     def fn(shape, dtype, device, memory_format):
         *batch_dims, num_frames, _, height, width = shape