diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
index 39d9dc103f4..b398227b480 100644
--- a/torchvision/prototype/transforms/_misc.py
+++ b/torchvision/prototype/transforms/_misc.py
@@ -76,12 +76,7 @@ def _check_inputs(self, sample: Any) -> Any:
         if has_any(sample, PIL.Image.Image):
             raise TypeError("LinearTransformation does not work on PIL Images")
 
-    def _transform(
-        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
-        # Image instance after linear transformation is not Image anymore due to unknown data range
-        # Thus we will return Tensor for input Image
-
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         shape = inpt.shape
         n = shape[-3] * shape[-2] * shape[-1]
         if n != self.transformation_matrix.shape[0]:
@@ -97,11 +92,15 @@ def _transform(
                 f"Got {inpt.device} vs {self.mean_vector.device}"
             )
 
-        flat_tensor = inpt.reshape(-1, n) - self.mean_vector
+        flat_inpt = inpt.reshape(-1, n) - self.mean_vector
+
+        transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype)
+        output = torch.mm(flat_inpt, transformation_matrix)
+        output = output.reshape(shape)
 
-        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
-        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
-        return transformed_tensor.reshape(shape)
+        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
+            output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+        return output
 
 
 class Normalize(Transform):
@@ -120,7 +119,7 @@ def _check_inputs(self, sample: Any) -> Any:
 
     def _transform(
         self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    ) -> Any:
        return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
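
For context, a minimal usage sketch of the behavioral change, assuming only the prototype namespaces that the diff itself touches (torchvision.prototype.transforms and torchvision.prototype.datapoints): after this patch, LinearTransformation re-wraps Image/Video inputs via wrap_like instead of always returning a plain tensor, while simple tensors still come back as plain tensors.

    import torch
    # Assumption: prototype namespaces as in the patched file paths above.
    from torchvision.prototype import datapoints, transforms

    c, h, w = 3, 8, 8
    n = c * h * w

    # Identity matrix and zero mean, so values pass through unchanged;
    # only the output *type* is of interest here.
    transform = transforms.LinearTransformation(
        transformation_matrix=torch.eye(n),
        mean_vector=torch.zeros(n),
    )

    image = datapoints.Image(torch.rand(c, h, w))
    plain = torch.rand(c, h, w)

    print(type(transform(image)))  # datapoints.Image (re-wrapped via wrap_like)
    print(type(transform(plain)))  # torch.Tensor (simple tensors stay plain)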