@@ -423,7 +423,7 @@ def load(self, device="cpu"):
         )
 
 
-def _parse_canvas_size(size, *, name="size"):
+def _parse_size(size, *, name="size"):
     if size == "random":
         raise ValueError("This should never happen")
     elif isinstance(size, int) and size > 0:
@@ -478,13 +478,13 @@ def load(self, device):
 
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
-    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
+    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
     memory_format: torch.memory_format = torch.contiguous_format
     canvas_size: Tuple[int, int] = dataclasses.field(init=False)
 
     def __post_init__(self):
-        self.canvas_size = self.canvas_size = self.shape[-2:]
+        self.spatial_size = self.canvas_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
     def load(self, device):
@@ -550,7 +550,7 @@ def make_image_loader(
 ):
     if not constant_alpha:
         raise ValueError("This should never happen")
-    size = _parse_canvas_size(size)
+    size = _parse_size(size)
     num_channels = get_num_channels(color_space)
 
     def fn(shape, dtype, device, memory_format):
@@ -590,7 +590,7 @@ def make_image_loaders(
 def make_image_loader_for_interpolation(
     size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
 ):
-    size = _parse_canvas_size(size)
+    size = _parse_size(size)
     num_channels = get_num_channels(color_space)
 
     def fn(shape, dtype, device, memory_format):
@@ -687,33 +687,33 @@ def sample_position(values, max_value):
687
687
)
688
688
689
689
690
- def make_bounding_box_loader (* , extra_dims = (), format , canvas_size = DEFAULT_PORTRAIT_SPATIAL_SIZE , dtype = torch .float32 ):
690
+ def make_bounding_box_loader (* , extra_dims = (), format , spatial_size = DEFAULT_PORTRAIT_SPATIAL_SIZE , dtype = torch .float32 ):
691
691
if isinstance (format , str ):
692
692
format = datapoints .BoundingBoxFormat [format ]
693
693
694
- canvas_size = _parse_canvas_size ( canvas_size , name = "canvas_size" )
694
+ spatial_size = _parse_size ( spatial_size , name = "canvas_size" )
695
695
696
696
def fn (shape , dtype , device ):
697
697
* batch_dims , num_coordinates = shape
698
698
if num_coordinates != 4 :
699
699
raise pytest .UsageError ()
700
700
701
701
return make_bounding_box (
702
- format = format , canvas_size = canvas_size , batch_dims = batch_dims , dtype = dtype , device = device
702
+ format = format , canvas_size = spatial_size , batch_dims = batch_dims , dtype = dtype , device = device
703
703
)
704
704
705
- return BoundingBoxesLoader (fn , shape = (* extra_dims , 4 ), dtype = dtype , format = format , spatial_size = canvas_size )
705
+ return BoundingBoxesLoader (fn , shape = (* extra_dims , 4 ), dtype = dtype , format = format , spatial_size = spatial_size )
706
706
707
707
708
708
def make_bounding_box_loaders (
709
709
* ,
710
710
extra_dims = DEFAULT_EXTRA_DIMS ,
711
711
formats = tuple (datapoints .BoundingBoxFormat ),
712
- canvas_size = DEFAULT_PORTRAIT_SPATIAL_SIZE ,
712
+ spatial_size = DEFAULT_PORTRAIT_SPATIAL_SIZE ,
713
713
dtypes = (torch .float32 , torch .float64 , torch .int64 ),
714
714
):
715
715
for params in combinations_grid (extra_dims = extra_dims , format = formats , dtype = dtypes ):
716
- yield make_bounding_box_loader (** params , canvas_size = canvas_size )
716
+ yield make_bounding_box_loader (** params , spatial_size = spatial_size )
717
717
718
718
719
719
make_bounding_boxes = from_loaders (make_bounding_box_loaders )
@@ -738,7 +738,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp
738
738
739
739
def make_detection_mask_loader (size = DEFAULT_PORTRAIT_SPATIAL_SIZE , * , num_objects = 5 , extra_dims = (), dtype = torch .uint8 ):
740
740
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
741
- size = _parse_canvas_size (size )
741
+ size = _parse_size (size )
742
742
743
743
def fn (shape , dtype , device ):
744
744
* batch_dims , num_objects , height , width = shape
@@ -779,15 +779,15 @@ def make_segmentation_mask_loader(
779
779
size = DEFAULT_PORTRAIT_SPATIAL_SIZE , * , num_categories = 10 , extra_dims = (), dtype = torch .uint8
780
780
):
781
781
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
782
- canvas_size = _parse_canvas_size (size )
782
+ size = _parse_size (size )
783
783
784
784
def fn (shape , dtype , device ):
785
785
* batch_dims , height , width = shape
786
786
return make_segmentation_mask (
787
787
(height , width ), num_categories = num_categories , batch_dims = batch_dims , dtype = dtype , device = device
788
788
)
789
789
790
- return MaskLoader (fn , shape = (* extra_dims , * canvas_size ), dtype = dtype )
790
+ return MaskLoader (fn , shape = (* extra_dims , * size ), dtype = dtype )
791
791
792
792
793
793
def make_segmentation_mask_loaders (
@@ -841,7 +841,7 @@ def make_video_loader(
841
841
extra_dims = (),
842
842
dtype = torch .uint8 ,
843
843
):
844
- size = _parse_canvas_size (size )
844
+ size = _parse_size (size )
845
845
846
846
def fn (shape , dtype , device , memory_format ):
847
847
* batch_dims , num_frames , _ , height , width = shape
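
The hunks above only show the first two branches of the renamed helper. For orientation, here is a minimal sketch of how `_parse_size` plausibly behaves, assuming a bare positive int expands to a square and a `(height, width)` pair is returned as-is; those two branches and the error message are assumptions, not part of the diff:

```python
import pytest
from collections.abc import Sequence


def _parse_size(size, *, name="size"):
    # The "random" sentinel is rejected outright, as shown in the diff.
    if size == "random":
        raise ValueError("This should never happen")
    elif isinstance(size, int) and size > 0:
        # Assumption: a single int is expanded to a square (height, width).
        return (size, size)
    elif isinstance(size, Sequence) and len(size) == 2:
        # Assumption: a (height, width) pair passes through unchanged.
        return tuple(size)
    else:
        # Assumption: invalid input surfaces as a test-usage error, matching
        # the `pytest.UsageError` pattern used elsewhere in this file.
        raise pytest.UsageError(
            f"'{name}' can be a positive integer or a (height, width) pair, but got {size}"
        )
```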
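And a hypothetical call site under the signatures visible in these hunks; the `"XYXY"` format name and the shape comments are assumptions drawn from the loader shapes shown above, not from this diff:

```python
import torch

# make_image_loader_for_interpolation's signature appears verbatim in the diff.
image_loader = make_image_loader_for_interpolation(size=(233, 147), color_space="RGB")
image = image_loader.load(device="cpu")  # presumably shape (3, 233, 147) for RGB

# The loader accepts a string format and converts it via
# datapoints.BoundingBoxFormat[format]; spatial_size is the renamed keyword.
bbox_loader = make_bounding_box_loader(format="XYXY", spatial_size=(32, 32))
boxes = bbox_loader.load(device="cpu")  # shape (4,), since shape=(*extra_dims, 4)
```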