@@ -238,7 +238,6 @@ def load(self, device):
238
238
239
239
@dataclasses .dataclass
240
240
class ImageLoader (TensorLoader ):
241
- color_space : datapoints .ColorSpace
242
241
spatial_size : Tuple [int , int ] = dataclasses .field (init = False )
243
242
num_channels : int = dataclasses .field (init = False )
244
243
@@ -248,10 +247,10 @@ def __post_init__(self):
248
247
249
248
250
249
NUM_CHANNELS_MAP = {
251
- datapoints . ColorSpace . GRAY : 1 ,
252
- datapoints . ColorSpace . GRAY_ALPHA : 2 ,
253
- datapoints . ColorSpace . RGB : 3 ,
254
- datapoints . ColorSpace . RGB_ALPHA : 4 ,
250
+ " GRAY" : 1 ,
251
+ " GRAY_ALPHA" : 2 ,
252
+ " RGB" : 3 ,
253
+ "RGBA" : 4 ,
255
254
}
256
255
257
256
@@ -265,7 +264,7 @@ def get_num_channels(color_space):
265
264
def make_image_loader (
266
265
size = "random" ,
267
266
* ,
268
- color_space = datapoints . ColorSpace . RGB ,
267
+ color_space = " RGB" ,
269
268
extra_dims = (),
270
269
dtype = torch .float32 ,
271
270
constant_alpha = True ,
@@ -276,11 +275,11 @@ def make_image_loader(
276
275
def fn (shape , dtype , device ):
277
276
max_value = get_max_value (dtype )
278
277
data = torch .testing .make_tensor (shape , low = 0 , high = max_value , dtype = dtype , device = device )
279
- if color_space in {datapoints . ColorSpace . GRAY_ALPHA , datapoints . ColorSpace . RGB_ALPHA } and constant_alpha :
278
+ if color_space in {" GRAY_ALPHA" , "RGBA" } and constant_alpha :
280
279
data [..., - 1 , :, :] = max_value
281
- return datapoints .Image (data , color_space = color_space )
280
+ return datapoints .Image (data )
282
281
283
- return ImageLoader (fn , shape = (* extra_dims , num_channels , * size ), dtype = dtype , color_space = color_space )
282
+ return ImageLoader (fn , shape = (* extra_dims , num_channels , * size ), dtype = dtype )
284
283
285
284
286
285
make_image = from_loader (make_image_loader )
@@ -290,10 +289,10 @@ def make_image_loaders(
290
289
* ,
291
290
sizes = DEFAULT_SPATIAL_SIZES ,
292
291
color_spaces = (
293
- datapoints . ColorSpace . GRAY ,
294
- datapoints . ColorSpace . GRAY_ALPHA ,
295
- datapoints . ColorSpace . RGB ,
296
- datapoints . ColorSpace . RGB_ALPHA ,
292
+ " GRAY" ,
293
+ " GRAY_ALPHA" ,
294
+ " RGB" ,
295
+ "RGBA" ,
297
296
),
298
297
extra_dims = DEFAULT_EXTRA_DIMS ,
299
298
dtypes = (torch .float32 , torch .uint8 ),
@@ -306,7 +305,7 @@ def make_image_loaders(
306
305
make_images = from_loaders (make_image_loaders )
307
306
308
307
309
- def make_image_loader_for_interpolation (size = "random" , * , color_space = datapoints . ColorSpace . RGB , dtype = torch .uint8 ):
308
+ def make_image_loader_for_interpolation (size = "random" , * , color_space = " RGB" , dtype = torch .uint8 ):
310
309
size = _parse_spatial_size (size )
311
310
num_channels = get_num_channels (color_space )
312
311
@@ -318,24 +317,24 @@ def fn(shape, dtype, device):
318
317
.resize ((width , height ))
319
318
.convert (
320
319
{
321
- datapoints . ColorSpace . GRAY : "L" ,
322
- datapoints . ColorSpace . GRAY_ALPHA : "LA" ,
323
- datapoints . ColorSpace . RGB : "RGB" ,
324
- datapoints . ColorSpace . RGB_ALPHA : "RGBA" ,
320
+ " GRAY" : "L" ,
321
+ " GRAY_ALPHA" : "LA" ,
322
+ " RGB" : "RGB" ,
323
+ "RGBA" : "RGBA" ,
325
324
}[color_space ]
326
325
)
327
326
)
328
327
329
328
image_tensor = convert_dtype_image_tensor (to_image_tensor (image_pil ).to (device = device ), dtype = dtype )
330
329
331
- return datapoints .Image (image_tensor , color_space = color_space )
330
+ return datapoints .Image (image_tensor )
332
331
333
- return ImageLoader (fn , shape = (num_channels , * size ), dtype = dtype , color_space = color_space )
332
+ return ImageLoader (fn , shape = (num_channels , * size ), dtype = dtype )
334
333
335
334
336
335
def make_image_loaders_for_interpolation (
337
336
sizes = ((233 , 147 ),),
338
- color_spaces = (datapoints . ColorSpace . RGB ,),
337
+ color_spaces = (" RGB" ,),
339
338
dtypes = (torch .uint8 ,),
340
339
):
341
340
for params in combinations_grid (size = sizes , color_space = color_spaces , dtype = dtypes ):
@@ -583,7 +582,7 @@ class VideoLoader(ImageLoader):
583
582
def make_video_loader (
584
583
size = "random" ,
585
584
* ,
586
- color_space = datapoints . ColorSpace . RGB ,
585
+ color_space = " RGB" ,
587
586
num_frames = "random" ,
588
587
extra_dims = (),
589
588
dtype = torch .uint8 ,
@@ -592,12 +591,10 @@ def make_video_loader(
592
591
num_frames = int (torch .randint (1 , 5 , ())) if num_frames == "random" else num_frames
593
592
594
593
def fn (shape , dtype , device ):
595
- video = make_image (size = shape [- 2 :], color_space = color_space , extra_dims = shape [:- 3 ], dtype = dtype , device = device )
596
- return datapoints .Video (video , color_space = color_space )
594
+ video = make_image (size = shape [- 2 :], extra_dims = shape [:- 3 ], dtype = dtype , device = device )
595
+ return datapoints .Video (video )
597
596
598
- return VideoLoader (
599
- fn , shape = (* extra_dims , num_frames , get_num_channels (color_space ), * size ), dtype = dtype , color_space = color_space
600
- )
597
+ return VideoLoader (fn , shape = (* extra_dims , num_frames , get_num_channels (color_space ), * size ), dtype = dtype )
601
598
602
599
603
600
make_video = from_loader (make_video_loader )
@@ -607,8 +604,8 @@ def make_video_loaders(
607
604
* ,
608
605
sizes = DEFAULT_SPATIAL_SIZES ,
609
606
color_spaces = (
610
- datapoints . ColorSpace . GRAY ,
611
- datapoints . ColorSpace . RGB ,
607
+ " GRAY" ,
608
+ " RGB" ,
612
609
),
613
610
num_frames = (1 , 0 , "random" ),
614
611
extra_dims = DEFAULT_EXTRA_DIMS ,
0 commit comments