diff --git a/docs/source/conf.py b/docs/source/conf.py index 258bbf6b5f2..b0fe63cb288 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -366,6 +366,11 @@ def inject_weight_metadata(app, what, name, obj, options, lines): lines += [".. table::", ""] lines += textwrap.indent(table, " " * 4).split("\n") lines.append("") + lines.append( + f"The inference transforms are available at ``{str(field)}.transforms`` and " + f"perform the following operations: {field.transforms().describe()}" + ) + lines.append("") def generate_weights_table(module, table_name, metrics, include_patterns=None, exclude_patterns=None): diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 4d503f44cc5..b009d45f1a4 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -25,6 +25,12 @@ def forward(self, img: Tensor) -> Tensor: img = F.pil_to_tensor(img) return F.convert_image_dtype(img, torch.float) + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def describe(self) -> str: + return "The images are rescaled to ``[0.0, 1.0]``." + class ImageClassification(nn.Module): def __init__( @@ -37,21 +43,38 @@ def __init__( interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() - self._crop_size = [crop_size] - self._size = [resize_size] - self._mean = list(mean) - self._std = list(std) - self._interpolation = interpolation + self.crop_size = [crop_size] + self.resize_size = [resize_size] + self.mean = list(mean) + self.std = list(std) + self.interpolation = interpolation def forward(self, img: Tensor) -> Tensor: - img = F.resize(img, self._size, interpolation=self._interpolation) - img = F.center_crop(img, self._crop_size) + img = F.resize(img, self.resize_size, interpolation=self.interpolation) + img = F.center_crop(img, self.crop_size) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) - img = F.normalize(img, mean=self._mean, std=self._std) + img = F.normalize(img, mean=self.mean, std=self.std) return img + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + format_string += f"\n crop_size={self.crop_size}" + format_string += f"\n resize_size={self.resize_size}" + format_string += f"\n mean={self.mean}" + format_string += f"\n std={self.std}" + format_string += f"\n interpolation={self.interpolation}" + format_string += "\n)" + return format_string + + def describe(self) -> str: + return ( + f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, " + f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to " + f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``." + ) + class VideoClassification(nn.Module): def __init__( @@ -64,11 +87,11 @@ def __init__( interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() - self._crop_size = list(crop_size) - self._size = list(resize_size) - self._mean = list(mean) - self._std = list(std) - self._interpolation = interpolation + self.crop_size = list(crop_size) + self.resize_size = list(resize_size) + self.mean = list(mean) + self.std = list(std) + self.interpolation = interpolation def forward(self, vid: Tensor) -> Tensor: need_squeeze = False @@ -79,11 +102,11 @@ def forward(self, vid: Tensor) -> Tensor: vid = vid.permute(0, 1, 4, 2, 3) # (N, T, H, W, C) => (N, T, C, H, W) N, T, C, H, W = vid.shape vid = vid.view(-1, C, H, W) - vid = F.resize(vid, self._size, interpolation=self._interpolation) - vid = F.center_crop(vid, self._crop_size) + vid = F.resize(vid, self.resize_size, interpolation=self.interpolation) + vid = F.center_crop(vid, self.crop_size) vid = F.convert_image_dtype(vid, torch.float) - vid = F.normalize(vid, mean=self._mean, std=self._std) - H, W = self._crop_size + vid = F.normalize(vid, mean=self.mean, std=self.std) + H, W = self.crop_size vid = vid.view(N, T, C, H, W) vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W) @@ -91,6 +114,23 @@ def forward(self, vid: Tensor) -> Tensor: vid = vid.squeeze(dim=0) return vid + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + format_string += f"\n crop_size={self.crop_size}" + format_string += f"\n resize_size={self.resize_size}" + format_string += f"\n mean={self.mean}" + format_string += f"\n std={self.std}" + format_string += f"\n interpolation={self.interpolation}" + format_string += "\n)" + return format_string + + def describe(self) -> str: + return ( + f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, " + f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to " + f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``." + ) + class SemanticSegmentation(nn.Module): def __init__( @@ -102,20 +142,35 @@ def __init__( interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() - self._size = [resize_size] if resize_size is not None else None - self._mean = list(mean) - self._std = list(std) - self._interpolation = interpolation + self.resize_size = [resize_size] if resize_size is not None else None + self.mean = list(mean) + self.std = list(std) + self.interpolation = interpolation def forward(self, img: Tensor) -> Tensor: - if isinstance(self._size, list): - img = F.resize(img, self._size, interpolation=self._interpolation) + if isinstance(self.resize_size, list): + img = F.resize(img, self.resize_size, interpolation=self.interpolation) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) - img = F.normalize(img, mean=self._mean, std=self._std) + img = F.normalize(img, mean=self.mean, std=self.std) return img + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + format_string += f"\n resize_size={self.resize_size}" + format_string += f"\n mean={self.mean}" + format_string += f"\n std={self.std}" + format_string += f"\n interpolation={self.interpolation}" + format_string += "\n)" + return format_string + + def describe(self) -> str: + return ( + f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. " + f"Then the values are rescaled to ``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``." + ) + class OpticalFlow(nn.Module): def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: @@ -135,3 +190,9 @@ def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: img2 = img2.contiguous() return img1, img2 + + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def describe(self) -> str: + return "The images are rescaled to ``[-1.0, 1.0]``."