From 40908977a7729ab335e0586d961755f2e9b1199b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 22 Aug 2022 11:27:27 +0100 Subject: [PATCH 1/2] Add RGB to BGR in S3D presets --- torchvision/models/video/s3d.py | 1 + torchvision/transforms/_presets.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/torchvision/models/video/s3d.py b/torchvision/models/video/s3d.py index f80d849683c..90861e57191 100644 --- a/torchvision/models/video/s3d.py +++ b/torchvision/models/video/s3d.py @@ -160,6 +160,7 @@ class S3D_Weights(WeightsEnum): resize_size=(256, 256), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + channel_order=(2, 1, 0), # RGB to BGR ), meta={ "min_size": (224, 224), diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 33b94d01c9d..7e79ef6e95c 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -89,6 +89,7 @@ def __init__( mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645), std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989), interpolation: InterpolationMode = InterpolationMode.BILINEAR, + channel_order: Optional[Tuple[int, int, int]] = None, ) -> None: super().__init__() self.crop_size = list(crop_size) @@ -96,6 +97,7 @@ def __init__( self.mean = list(mean) self.std = list(std) self.interpolation = interpolation + self.channel_order = channel_order def forward(self, vid: Tensor) -> Tensor: need_squeeze = False @@ -105,6 +107,8 @@ def forward(self, vid: Tensor) -> Tensor: N, T, C, H, W = vid.shape vid = vid.view(-1, C, H, W) + if self.channel_order is not None: + vid = vid[:, self.channel_order] vid = F.resize(vid, self.resize_size, interpolation=self.interpolation) vid = F.center_crop(vid, self.crop_size) vid = F.convert_image_dtype(vid, torch.float) @@ -124,12 +128,15 @@ def __repr__(self) -> str: format_string += f"\n mean={self.mean}" format_string += f"\n std={self.std}" format_string += f"\n interpolation={self.interpolation}" + format_string += f"\n channel_order={self.channel_order}" format_string += "\n)" return format_string def describe(self) -> str: - return ( - "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. " + s = "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. " + if self.channel_order: + s += f"Remaps the order within the channels dimension using ``channel_order={self.channel_order}``. " + s += ( f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, " f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to " f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output " From f2c9a1c570e7730b765125679abf477ce2a15c08 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 22 Aug 2022 11:56:44 +0100 Subject: [PATCH 2/2] Move channel reorder after the normalization. --- torchvision/transforms/_presets.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 7e79ef6e95c..1b41f68d672 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -107,12 +107,12 @@ def forward(self, vid: Tensor) -> Tensor: N, T, C, H, W = vid.shape vid = vid.view(-1, C, H, W) - if self.channel_order is not None: - vid = vid[:, self.channel_order] vid = F.resize(vid, self.resize_size, interpolation=self.interpolation) vid = F.center_crop(vid, self.crop_size) vid = F.convert_image_dtype(vid, torch.float) vid = F.normalize(vid, mean=self.mean, std=self.std) + if self.channel_order is not None: + vid = vid[:, self.channel_order] H, W = self.crop_size vid = vid.view(N, T, C, H, W) vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W) @@ -133,15 +133,16 @@ def __repr__(self) -> str: return format_string def describe(self) -> str: - s = "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. " - if self.channel_order: - s += f"Remaps the order within the channels dimension using ``channel_order={self.channel_order}``. " - s += ( + s = ( + "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. " f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, " f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to " - f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output " - "dimensions are permuted to ``(..., C, T, H, W)`` tensors." + f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. " ) + if self.channel_order is not None: + s += f"Remaps the order within the channels dimension using ``channel_order={self.channel_order}``. " + s += "Finally the output dimensions are permuted to ``(..., C, T, H, W)`` tensors." + return s class SemanticSegmentation(nn.Module):