
Commit ac6942e

NicolasHug authored and facebook-github-bot committed

[fbsync] Let LinearTransformation return datapoints instead of tensors (#7244)

Reviewed By: vmoens
Differential Revision: D44416585
fbshipit-source-id: 8f61172e676563ad7eefbdb294eb76dab6dee2eb
1 parent 18eabd4 commit ac6942e
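
In practice, this change means that passing a datapoints.Image or datapoints.Video through LinearTransformation now returns the same datapoint type rather than a plain torch.Tensor. Below is a minimal sketch of the new behaviour; it assumes the prototype-era namespaces torchvision.prototype.transforms and torchvision.prototype.datapoints (these later moved to torchvision.transforms.v2 and torchvision.datapoints), and uses an identity matrix purely for illustration.

import torch

# Assumed prototype-era import paths; later releases expose the equivalents
# under torchvision.transforms.v2 and torchvision.datapoints.
from torchvision.prototype import datapoints, transforms

# For a C x H x W input, n = C * H * W; LinearTransformation expects an
# n x n transformation matrix and an n-element mean vector (checked in _transform).
n = 3 * 8 * 8
transform = transforms.LinearTransformation(
    transformation_matrix=torch.eye(n),  # identity matrix, just for illustration
    mean_vector=torch.zeros(n),
)

image = datapoints.Image(torch.rand(3, 8, 8))
out = transform(image)

# With this commit, the datapoint type survives the transform.
print(type(out).__name__)  # Image (previously a plain Tensor)

Before this commit the output was deliberately downgraded to a plain tensor on the grounds that the data range after an arbitrary linear map is unknown; the new _transform re-wraps the result via wrap_like instead.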

File tree

1 file changed (+10 -11 lines)

  • torchvision/prototype/transforms


torchvision/prototype/transforms/_misc.py

Lines changed: 10 additions & 11 deletions
@@ -76,12 +76,7 @@ def _check_inputs(self, sample: Any) -> Any:
         if has_any(sample, PIL.Image.Image):
             raise TypeError("LinearTransformation does not work on PIL Images")
 
-    def _transform(
-        self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
-        # Image instance after linear transformation is not Image anymore due to unknown data range
-        # Thus we will return Tensor for input Image
-
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         shape = inpt.shape
         n = shape[-3] * shape[-2] * shape[-1]
         if n != self.transformation_matrix.shape[0]:
@@ -97,11 +92,15 @@ def _transform(
                 f"Got {inpt.device} vs {self.mean_vector.device}"
             )
 
-        flat_tensor = inpt.reshape(-1, n) - self.mean_vector
+        flat_inpt = inpt.reshape(-1, n) - self.mean_vector
+
+        transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype)
+        output = torch.mm(flat_inpt, transformation_matrix)
+        output = output.reshape(shape)
 
-        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
-        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
-        return transformed_tensor.reshape(shape)
+        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
+            output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
+        return output
 
 
 class Normalize(Transform):
@@ -120,7 +119,7 @@ def _check_inputs(self, sample: Any) -> Any:
 
     def _transform(
         self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    ) -> Any:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
 
 