Skip to content

Commit 1b41525

Browse files
authored
Normalize, LinearTransformation are scriptable (#2645)
* [WIP] All transforms are now derived from torch.nn.Module - Compose, RandomApply, Normalize can be jit scripted * Fixed flake8 * Updated code and docs - added getattr to Lambda and tests - updated code and docs of Compose - added failing test with append/extend on Composed.transforms * Fixed flake8 * Updated code, tests and docs
1 parent 8dfcff7 commit 1b41525

File tree

4 files changed

+132
-28
lines changed

4 files changed

+132
-28
lines changed

docs/source/transforms.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,26 @@ All transformations accept PIL Image, Tensor Image or batch of Tensor Images as
1414
Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is a number of images in the batch. Deterministic or
1515
random transformations applied on the batch of Tensor Images identically transform all the images of the batch.
1616

17+
18+
Scriptable transforms
19+
^^^^^^^^^^^^^^^^^^^^^
20+
21+
In order to script the transformations, please use ``torch.nn.Sequential`` instead of :class:`Compose`.
22+
23+
.. code:: python
24+
25+
transforms = torch.nn.Sequential(
26+
transforms.CenterCrop(10),
27+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
28+
)
29+
scripted_transforms = torch.jit.script(transforms)
30+
31+
Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
32+
``lambda`` functions or ``PIL.Image``.
33+
34+
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.
35+
36+
1737
.. autoclass:: Compose
1838

1939
Transforms on PIL Image

test/test_transforms_tensor.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,63 @@ def test_to_grayscale(self):
376376
"RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
377377
)
378378

379+
def test_normalize(self):
380+
tensor, _ = self._create_data(26, 34, device=self.device)
381+
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
382+
383+
tensor = tensor.to(dtype=torch.float32) / 255.0
384+
# test for class interface
385+
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
386+
scripted_fn = torch.jit.script(fn)
387+
388+
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
389+
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
390+
391+
def test_linear_transformation(self):
392+
c, h, w = 3, 24, 32
393+
394+
tensor, _ = self._create_data(h, w, channels=c, device=self.device)
395+
396+
matrix = torch.rand(c * h * w, c * h * w, device=self.device)
397+
mean_vector = torch.rand(c * h * w, device=self.device)
398+
399+
fn = T.LinearTransformation(matrix, mean_vector)
400+
scripted_fn = torch.jit.script(fn)
401+
402+
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
403+
404+
batch_tensors = torch.rand(4, c, h, w, device=self.device)
405+
# We skip some tests from _test_transform_vs_scripted_on_batch as
406+
# results for scripted and non-scripted transformations are not exactly the same
407+
torch.manual_seed(12)
408+
transformed_batch = fn(batch_tensors)
409+
torch.manual_seed(12)
410+
s_transformed_batch = scripted_fn(batch_tensors)
411+
self.assertTrue(transformed_batch.equal(s_transformed_batch))
412+
413+
def test_compose(self):
414+
tensor, _ = self._create_data(26, 34, device=self.device)
415+
tensor = tensor.to(dtype=torch.float32) / 255.0
416+
417+
transforms = T.Compose([
418+
T.CenterCrop(10),
419+
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
420+
])
421+
s_transforms = torch.nn.Sequential(*transforms.transforms)
422+
423+
scripted_fn = torch.jit.script(s_transforms)
424+
torch.manual_seed(12)
425+
transformed_tensor = transforms(tensor)
426+
torch.manual_seed(12)
427+
transformed_tensor_script = scripted_fn(tensor)
428+
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
429+
430+
t = T.Compose([
431+
lambda x: x,
432+
])
433+
with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
434+
torch.jit.script(t)
435+
379436

380437
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
381438
class CUDATester(Tester):

torchvision/transforms/functional.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def to_pil_image(pic, mode=None):
283283
return Image.fromarray(npimg, mode=mode)
284284

285285

286-
def normalize(tensor, mean, std, inplace=False):
286+
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
287287
"""Normalize a tensor image with mean and standard deviation.
288288
289289
.. note::
@@ -292,19 +292,19 @@ def normalize(tensor, mean, std, inplace=False):
292292
See :class:`~torchvision.transforms.Normalize` for more details.
293293
294294
Args:
295-
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
295+
tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
296296
mean (sequence): Sequence of means for each channel.
297297
std (sequence): Sequence of standard deviations for each channel.
298298
inplace(bool,optional): Bool to make this operation inplace.
299299
300300
Returns:
301301
Tensor: Normalized Tensor image.
302302
"""
303-
if not torch.is_tensor(tensor):
304-
raise TypeError('tensor should be a torch tensor. Got {}.'.format(type(tensor)))
303+
if not isinstance(tensor, torch.Tensor):
304+
raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
305305

306-
if tensor.ndimension() != 3:
307-
raise ValueError('Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = '
306+
if tensor.ndim < 3:
307+
raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
308308
'{}.'.format(tensor.size()))
309309

310310
if not inplace:
@@ -316,9 +316,9 @@ def normalize(tensor, mean, std, inplace=False):
316316
if (std == 0).any():
317317
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
318318
if mean.ndim == 1:
319-
mean = mean[:, None, None]
319+
mean = mean.view(-1, 1, 1)
320320
if std.ndim == 1:
321-
std = std[:, None, None]
321+
std = std.view(-1, 1, 1)
322322
tensor.sub_(mean).div_(std)
323323
return tensor
324324

torchvision/transforms/transforms.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import random
44
import warnings
55
from collections.abc import Sequence
6-
from typing import Tuple, List, Optional
6+
from typing import Tuple, List, Optional, Any
77

88
import torch
99
from PIL import Image
@@ -33,7 +33,7 @@
3333
}
3434

3535

36-
class Compose(object):
36+
class Compose:
3737
"""Composes several transforms together.
3838
3939
Args:
@@ -44,6 +44,19 @@ class Compose(object):
4444
>>> transforms.CenterCrop(10),
4545
>>> transforms.ToTensor(),
4646
>>> ])
47+
48+
.. note::
49+
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
50+
51+
>>> transforms = torch.nn.Sequential(
52+
>>> transforms.CenterCrop(10),
53+
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
54+
>>> )
55+
>>> scripted_transforms = torch.jit.script(transforms)
56+
57+
Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
58+
``lambda`` functions or ``PIL.Image``.
59+
4760
"""
4861

4962
def __init__(self, transforms):
@@ -63,7 +76,7 @@ def __repr__(self):
6376
return format_string
6477

6578

66-
class ToTensor(object):
79+
class ToTensor:
6780
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
6881
6982
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
@@ -94,7 +107,7 @@ def __repr__(self):
94107
return self.__class__.__name__ + '()'
95108

96109

97-
class PILToTensor(object):
110+
class PILToTensor:
98111
"""Convert a ``PIL Image`` to a tensor of the same type.
99112
100113
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
@@ -114,7 +127,7 @@ def __repr__(self):
114127
return self.__class__.__name__ + '()'
115128

116129

117-
class ConvertImageDtype(object):
130+
class ConvertImageDtype:
118131
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
119132
120133
Args:
@@ -139,7 +152,7 @@ def __call__(self, image: torch.Tensor) -> torch.Tensor:
139152
return F.convert_image_dtype(image, self.dtype)
140153

141154

142-
class ToPILImage(object):
155+
class ToPILImage:
143156
"""Convert a tensor or an ndarray to PIL Image.
144157
145158
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
@@ -178,7 +191,7 @@ def __repr__(self):
178191
return format_string
179192

180193

181-
class Normalize(object):
194+
class Normalize(torch.nn.Module):
182195
"""Normalize a tensor image with mean and standard deviation.
183196
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
184197
channels, this transform will normalize each channel of the input
@@ -196,11 +209,12 @@ class Normalize(object):
196209
"""
197210

198211
def __init__(self, mean, std, inplace=False):
212+
super().__init__()
199213
self.mean = mean
200214
self.std = std
201215
self.inplace = inplace
202216

203-
def __call__(self, tensor):
217+
def forward(self, tensor: Tensor) -> Tensor:
204218
"""
205219
Args:
206220
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
@@ -358,15 +372,16 @@ def __repr__(self):
358372
format(self.padding, self.fill, self.padding_mode)
359373

360374

361-
class Lambda(object):
375+
class Lambda:
362376
"""Apply a user-defined lambda as a transform.
363377
364378
Args:
365379
lambd (function): Lambda/function to be used for transform.
366380
"""
367381

368382
def __init__(self, lambd):
369-
assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
383+
if not callable(lambd):
384+
raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__)))
370385
self.lambd = lambd
371386

372387
def __call__(self, img):
@@ -376,7 +391,7 @@ def __repr__(self):
376391
return self.__class__.__name__ + '()'
377392

378393

379-
class RandomTransforms(object):
394+
class RandomTransforms:
380395
"""Base class for a list of transformations with randomness
381396
382397
Args:
@@ -408,7 +423,7 @@ class RandomApply(RandomTransforms):
408423
"""
409424

410425
def __init__(self, transforms, p=0.5):
411-
super(RandomApply, self).__init__(transforms)
426+
super().__init__(transforms)
412427
self.p = p
413428

414429
def __call__(self, img):
@@ -897,7 +912,7 @@ def __repr__(self):
897912
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
898913

899914

900-
class LinearTransformation(object):
915+
class LinearTransformation(torch.nn.Module):
901916
"""Transform a tensor image with a square transformation matrix and a mean_vector computed
902917
offline.
903918
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
@@ -916,6 +931,7 @@ class LinearTransformation(object):
916931
"""
917932

918933
def __init__(self, transformation_matrix, mean_vector):
934+
super().__init__()
919935
if transformation_matrix.size(0) != transformation_matrix.size(1):
920936
raise ValueError("transformation_matrix should be square. Got " +
921937
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
@@ -925,24 +941,35 @@ def __init__(self, transformation_matrix, mean_vector):
925941
" as any one of the dimensions of the transformation_matrix [{}]"
926942
.format(tuple(transformation_matrix.size())))
927943

944+
if transformation_matrix.device != mean_vector.device:
945+
raise ValueError("Input tensors should be on the same device. Got {} and {}"
946+
.format(transformation_matrix.device, mean_vector.device))
947+
928948
self.transformation_matrix = transformation_matrix
929949
self.mean_vector = mean_vector
930950

931-
def __call__(self, tensor):
951+
def forward(self, tensor: Tensor) -> Tensor:
932952
"""
933953
Args:
934954
tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
935955
936956
Returns:
937957
Tensor: Transformed image.
938958
"""
939-
if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
940-
raise ValueError("tensor and transformation matrix have incompatible shape." +
941-
"[{} x {} x {}] != ".format(*tensor.size()) +
942-
"{}".format(self.transformation_matrix.size(0)))
943-
flat_tensor = tensor.view(1, -1) - self.mean_vector
959+
shape = tensor.shape
960+
n = shape[-3] * shape[-2] * shape[-1]
961+
if n != self.transformation_matrix.shape[0]:
962+
raise ValueError("Input tensor and transformation matrix have incompatible shape." +
963+
"[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) +
964+
"{}".format(self.transformation_matrix.shape[0]))
965+
966+
if tensor.device.type != self.mean_vector.device.type:
967+
raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. "
968+
"Got {} vs {}".format(tensor.device, self.mean_vector.device))
969+
970+
flat_tensor = tensor.view(-1, n) - self.mean_vector
944971
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
945-
tensor = transformed_tensor.view(tensor.size())
972+
tensor = transformed_tensor.view(shape)
946973
return tensor
947974

948975
def __repr__(self):

0 commit comments

Comments
 (0)