@@ -184,13 +184,18 @@ def load(self, device="cpu"):
        return args, kwargs


-DEFAULT_SQUARE_IMAGE_SIZE = 15
-DEFAULT_LANDSCAPE_IMAGE_SIZE = (7, 33)
-DEFAULT_PORTRAIT_IMAGE_SIZE = (31, 9)
-DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, "random")
+DEFAULT_SQUARE_SPATIAL_SIZE = 15
+DEFAULT_LANDSCAPE_SPATIAL_SIZE = (7, 33)
+DEFAULT_PORTRAIT_SPATIAL_SIZE = (31, 9)
+DEFAULT_SPATIAL_SIZES = (
+    DEFAULT_LANDSCAPE_SPATIAL_SIZE,
+    DEFAULT_PORTRAIT_SPATIAL_SIZE,
+    DEFAULT_SQUARE_SPATIAL_SIZE,
+    "random",
+)


-def _parse_image_size(size, *, name="size"):
+def _parse_spatial_size(size, *, name="size"):
    if size == "random":
        return tuple(torch.randint(15, 33, (2,)).tolist())
    elif isinstance(size, int) and size > 0:
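This hunk cuts off after the `isinstance(size, int)` branch, but the visible branches imply the helper's full contract. A sketch of how it presumably resolves a size argument (the sequence branch and the error message are assumptions, not shown in the diff):

    def _parse_spatial_size(size, *, name="size"):
        # "random" draws height and width independently from [15, 33)
        if size == "random":
            return tuple(torch.randint(15, 33, (2,)).tolist())
        # a positive int presumably means a square (size, size)
        elif isinstance(size, int) and size > 0:
            return (size, size)
        # a (height, width) pair presumably passes through as a tuple
        elif isinstance(size, (tuple, list)) and len(size) == 2:
            return tuple(size)
        else:
            raise pytest.UsageError(
                f"'{name}' must be 'random', a positive int, or a (height, width) pair, but got {size}"
            )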
@@ -246,11 +251,11 @@ def load(self, device):
@dataclasses.dataclass
class ImageLoader(TensorLoader):
    color_space: features.ColorSpace
-    image_size: Tuple[int, int] = dataclasses.field(init=False)
+    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
    num_channels: int = dataclasses.field(init=False)

    def __post_init__(self):
-        self.image_size = self.shape[-2:]
+        self.spatial_size = self.shape[-2:]
        self.num_channels = self.shape[-3]

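Both renamed fields are declared with `dataclasses.field(init=False)`, so they are not constructor arguments: `__post_init__` derives them from the trailing dimensions of the loader's `shape`. A minimal usage sketch (hypothetical values, using the `make_image_loader` factory from the next hunk):

    loader = make_image_loader(size=(7, 33), color_space=features.ColorSpace.RGB)
    assert loader.spatial_size == (7, 33)  # shape[-2:], filled in by __post_init__
    assert loader.num_channels == 3        # shape[-3]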
@@ -277,7 +282,7 @@ def make_image_loader(
    dtype=torch.float32,
    constant_alpha=True,
):
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
    num_channels = get_num_channels(color_space)

    def fn(shape, dtype, device):
@@ -295,7 +300,7 @@ def fn(shape, dtype, device):

def make_image_loaders(
    *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
    color_spaces=(
        features.ColorSpace.GRAY,
        features.ColorSpace.GRAY_ALPHA,
@@ -316,7 +321,7 @@ def make_image_loaders(
@dataclasses.dataclass
class BoundingBoxLoader(TensorLoader):
    format: features.BoundingBoxFormat
-    image_size: Tuple[int, int]
+    spatial_size: Tuple[int, int]


def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
@@ -331,7 +336,7 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
    ).reshape(low.shape)


-def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32):
    if isinstance(format, str):
        format = features.BoundingBoxFormat[format]
    if format not in {
@@ -341,7 +346,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32):
    }:
        raise pytest.UsageError(f"Can't make bounding box in format {format}")

-    image_size = _parse_image_size(image_size, name="image_size")
+    spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")

    def fn(shape, dtype, device):
        *extra_dims, num_coordinates = shape
@@ -350,10 +355,10 @@ def fn(shape, dtype, device):

        if any(dim == 0 for dim in extra_dims):
            return features.BoundingBox(
-                torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=image_size
+                torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
            )

-        height, width = image_size
+        height, width = spatial_size

        if format == features.BoundingBoxFormat.XYXY:
            x1 = torch.randint(0, width // 2, extra_dims)
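The sampling bounds are what keep every generated box inside the canvas: `x1` is drawn from the left half of `width` (and `y1` from the top half of `height`) so that a strictly larger second coordinate always exists. The rest of the XYXY branch is not visible in this hunk; presumably it continues along these lines (a sketch, reusing the `randint_with_tensor_bounds` helper defined above):

    # assumed continuation: the opposite corner is sampled per element above
    # the first one, so x1 < x2 < width and y1 < y2 < height
    y1 = torch.randint(0, height // 2, extra_dims)
    x2 = randint_with_tensor_bounds(x1 + 1, width)
    y2 = randint_with_tensor_bounds(y1 + 1, height)
    parts = (x1, y1, x2, y2)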
@@ -375,10 +380,10 @@ def fn(shape, dtype, device):
            parts = (cx, cy, w, h)

        return features.BoundingBox(
-            torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=image_size
+            torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
        )

-    return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, image_size=image_size)
+    return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)


make_bounding_box = from_loader(make_bounding_box_loader)
@@ -388,11 +393,11 @@ def make_bounding_box_loaders(
    *,
    extra_dims=DEFAULT_EXTRA_DIMS,
    formats=tuple(features.BoundingBoxFormat),
-    image_size="random",
+    spatial_size="random",
    dtypes=(torch.float32, torch.int64),
):
    for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
-        yield make_bounding_box_loader(**params, image_size=image_size)
+        yield make_bounding_box_loader(**params, spatial_size=spatial_size)


make_bounding_boxes = from_loaders(make_bounding_box_loaders)
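After this rename, bounding-box call sites pass `spatial_size` where they previously passed `image_size`. A hypothetical usage sketch (assuming `TensorLoader.load` materializes the tensor, as the hunk headers above suggest):

    loader = make_bounding_box_loader(format="XYXY", spatial_size=(32, 32))
    assert loader.spatial_size == (32, 32)
    box = loader.load(device="cpu")  # a features.BoundingBox fitting a 32x32 canvas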
@@ -475,7 +480,7 @@ class MaskLoader(TensorLoader):

def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8):
    # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
    num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects

    def fn(shape, dtype, device):
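As the comment spells out, detection masks carry one spatial plane per object, unlike the `(*, H, W)` segmentation masks below where the category is encoded in the pixel values. Illustrative shapes (hypothetical arguments; assumes the loader's `shape` is `(*extra_dims, num_objects, *size)`):

    loader = make_detection_mask_loader(size=(16, 16), num_objects=5)
    assert loader.shape == (5, 16, 16)  # (N, H, W): one mask plane per object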
@@ -489,7 +494,7 @@ def fn(shape, dtype, device):


def make_detection_mask_loaders(
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
    num_objects=(1, 0, "random"),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.uint8,),
@@ -503,7 +508,7 @@ def make_detection_mask_loaders(

def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8):
    # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
    num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories

    def fn(shape, dtype, device):
@@ -518,7 +523,7 @@ def fn(shape, dtype, device):

def make_segmentation_mask_loaders(
    *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
    num_categories=(1, 2, "random"),
    extra_dims=DEFAULT_EXTRA_DIMS,
    dtypes=(torch.uint8,),
@@ -532,7 +537,7 @@ def make_segmentation_mask_loaders(

def make_mask_loaders(
    *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
    num_objects=(1, 0, "random"),
    num_categories=(1, 2, "random"),
    extra_dims=DEFAULT_EXTRA_DIMS,
@@ -559,7 +564,7 @@ def make_video_loader(
    extra_dims=(),
    dtype=torch.uint8,
):
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
    num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

    def fn(shape, dtype, device):
@@ -576,7 +581,7 @@ def fn(shape, dtype, device):

def make_video_loaders(
    *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
    color_spaces=(
        features.ColorSpace.GRAY,
        features.ColorSpace.RGB,
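Taken together, the change is a mechanical rename at every call site; roughly the following before/after (illustrative only, and assuming `from_loader`/`from_loaders` forward keyword arguments to the underlying factories):

    # before
    make_bounding_box(format="XYWH", image_size=(32, 32))
    make_image_loaders(sizes=DEFAULT_IMAGE_SIZES)

    # after
    make_bounding_box(format="XYWH", spatial_size=(32, 32))
    make_image_loaders(sizes=DEFAULT_SPATIAL_SIZES)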