Skip to content

Commit 5dd9594

Browse files
authored
Remove color_space metadata and ConvertColorSpace() transform (#7120)
1 parent c206a47 commit 5dd9594

16 files changed

106 additions and 506 deletions

test/prototype_common_utils.py

Lines changed: 26 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -238,7 +238,6 @@ def load(self, device):
238238

239239
@dataclasses.dataclass
240240
class ImageLoader(TensorLoader):
241-
color_space: datapoints.ColorSpace
242241
spatial_size: Tuple[int, int] = dataclasses.field(init=False)
243242
num_channels: int = dataclasses.field(init=False)
244243

@@ -248,10 +247,10 @@ def __post_init__(self):
248247

249248

250249
NUM_CHANNELS_MAP = {
251-
datapoints.ColorSpace.GRAY: 1,
252-
datapoints.ColorSpace.GRAY_ALPHA: 2,
253-
datapoints.ColorSpace.RGB: 3,
254-
datapoints.ColorSpace.RGB_ALPHA: 4,
250+
"GRAY": 1,
251+
"GRAY_ALPHA": 2,
252+
"RGB": 3,
253+
"RGBA": 4,
255254
}
256255

257256

@@ -265,7 +264,7 @@ def get_num_channels(color_space):
265264
def make_image_loader(
266265
size="random",
267266
*,
268-
color_space=datapoints.ColorSpace.RGB,
267+
color_space="RGB",
269268
extra_dims=(),
270269
dtype=torch.float32,
271270
constant_alpha=True,
@@ -276,11 +275,11 @@ def make_image_loader(
276275
def fn(shape, dtype, device):
277276
max_value = get_max_value(dtype)
278277
data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
279-
if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha:
278+
if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
280279
data[..., -1, :, :] = max_value
281-
return datapoints.Image(data, color_space=color_space)
280+
return datapoints.Image(data)
282281

283-
return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space)
282+
return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
284283

285284

286285
make_image = from_loader(make_image_loader)
@@ -290,10 +289,10 @@ def make_image_loaders(
290289
*,
291290
sizes=DEFAULT_SPATIAL_SIZES,
292291
color_spaces=(
293-
datapoints.ColorSpace.GRAY,
294-
datapoints.ColorSpace.GRAY_ALPHA,
295-
datapoints.ColorSpace.RGB,
296-
datapoints.ColorSpace.RGB_ALPHA,
292+
"GRAY",
293+
"GRAY_ALPHA",
294+
"RGB",
295+
"RGBA",
297296
),
298297
extra_dims=DEFAULT_EXTRA_DIMS,
299298
dtypes=(torch.float32, torch.uint8),
@@ -306,7 +305,7 @@ def make_image_loaders(
306305
make_images = from_loaders(make_image_loaders)
307306

308307

309-
def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8):
308+
def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
310309
size = _parse_spatial_size(size)
311310
num_channels = get_num_channels(color_space)
312311

@@ -318,24 +317,24 @@ def fn(shape, dtype, device):
318317
.resize((width, height))
319318
.convert(
320319
{
321-
datapoints.ColorSpace.GRAY: "L",
322-
datapoints.ColorSpace.GRAY_ALPHA: "LA",
323-
datapoints.ColorSpace.RGB: "RGB",
324-
datapoints.ColorSpace.RGB_ALPHA: "RGBA",
320+
"GRAY": "L",
321+
"GRAY_ALPHA": "LA",
322+
"RGB": "RGB",
323+
"RGBA": "RGBA",
325324
}[color_space]
326325
)
327326
)
328327

329328
image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
330329

331-
return datapoints.Image(image_tensor, color_space=color_space)
330+
return datapoints.Image(image_tensor)
332331

333-
return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space)
332+
return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
334333

335334

336335
def make_image_loaders_for_interpolation(
337336
sizes=((233, 147),),
338-
color_spaces=(datapoints.ColorSpace.RGB,),
337+
color_spaces=("RGB",),
339338
dtypes=(torch.uint8,),
340339
):
341340
for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
@@ -583,7 +582,7 @@ class VideoLoader(ImageLoader):
583582
def make_video_loader(
584583
size="random",
585584
*,
586-
color_space=datapoints.ColorSpace.RGB,
585+
color_space="RGB",
587586
num_frames="random",
588587
extra_dims=(),
589588
dtype=torch.uint8,
@@ -592,12 +591,10 @@ def make_video_loader(
592591
num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
593592

594593
def fn(shape, dtype, device):
595-
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
596-
return datapoints.Video(video, color_space=color_space)
594+
video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
595+
return datapoints.Video(video)
597596

598-
return VideoLoader(
599-
fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
600-
)
597+
return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
601598

602599

603600
make_video = from_loader(make_video_loader)
@@ -607,8 +604,8 @@ def make_video_loaders(
607604
*,
608605
sizes=DEFAULT_SPATIAL_SIZES,
609606
color_spaces=(
610-
datapoints.ColorSpace.GRAY,
611-
datapoints.ColorSpace.RGB,
607+
"GRAY",
608+
"RGB",
612609
),
613610
num_frames=(1, 0, "random"),
614611
extra_dims=DEFAULT_EXTRA_DIMS,

0 commit comments

Comments
 (0)