From 59b301c3e2f3f6f7f96f0f39fb2a7f090399687e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Feb 2023 13:23:09 +0000 Subject: [PATCH 1/2] Let LinearTransformation return datapoints instead of tensors --- torchvision/prototype/transforms/_misc.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 39d9dc103f4..19ad7a2f1b9 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -78,10 +78,7 @@ def _check_inputs(self, sample: Any) -> Any: def _transform( self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] - ) -> torch.Tensor: - # Image instance after linear transformation is not Image anymore due to unknown data range - # Thus we will return Tensor for input Image - + ) -> Any: shape = inpt.shape n = shape[-3] * shape[-2] * shape[-1] if n != self.transformation_matrix.shape[0]: @@ -97,11 +94,15 @@ def _transform( f"Got {inpt.device} vs {self.mean_vector.device}" ) - flat_tensor = inpt.reshape(-1, n) - self.mean_vector + flat_inpt = inpt.reshape(-1, n) - self.mean_vector - transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype) - transformed_tensor = torch.mm(flat_tensor, transformation_matrix) - return transformed_tensor.reshape(shape) + transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype) + output = torch.mm(flat_inpt, transformation_matrix) + output = output.reshape(shape) + + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] + return output class Normalize(Transform): @@ -120,7 +121,7 @@ def _check_inputs(self, sample: Any) -> Any: def _transform( self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] - ) -> torch.Tensor: + ) -> Any: return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) From 243a5d03053361254b615ef86f82876a1ae23ea3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Feb 2023 13:35:09 +0000 Subject: [PATCH 2/2] Any --- torchvision/prototype/transforms/_misc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 19ad7a2f1b9..b398227b480 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -76,9 +76,7 @@ def _check_inputs(self, sample: Any) -> Any: if has_any(sample, PIL.Image.Image): raise TypeError("LinearTransformation does not work on PIL Images") - def _transform( - self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any] - ) -> Any: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: shape = inpt.shape n = shape[-3] * shape[-2] * shape[-1] if n != self.transformation_matrix.shape[0]: