3
3
import random
4
4
import warnings
5
5
from collections .abc import Sequence
6
- from typing import Tuple , List , Optional
6
+ from typing import Tuple , List , Optional , Any
7
7
8
8
import torch
9
9
from PIL import Image
33
33
}
34
34
35
35
36
- class Compose ( object ) :
36
+ class Compose :
37
37
"""Composes several transforms together.
38
38
39
39
Args:
@@ -44,6 +44,19 @@ class Compose(object):
44
44
>>> transforms.CenterCrop(10),
45
45
>>> transforms.ToTensor(),
46
46
>>> ])
47
+
48
+ .. note::
49
+ In order to script the transformations, please use ``torch.nn.Sequential`` as below.
50
+
51
+ >>> transforms = torch.nn.Sequential(
52
+ >>> transforms.CenterCrop(10),
53
+ >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
54
+ >>> )
55
+ >>> scripted_transforms = torch.jit.script(transforms)
56
+
57
+ Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
58
+ `lambda` functions or ``PIL.Image``.
59
+
47
60
"""
48
61
49
62
def __init__ (self , transforms ):
@@ -63,7 +76,7 @@ def __repr__(self):
63
76
return format_string
64
77
65
78
66
- class ToTensor ( object ) :
79
+ class ToTensor :
67
80
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
68
81
69
82
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
@@ -94,7 +107,7 @@ def __repr__(self):
94
107
return self .__class__ .__name__ + '()'
95
108
96
109
97
- class PILToTensor ( object ) :
110
+ class PILToTensor :
98
111
"""Convert a ``PIL Image`` to a tensor of the same type.
99
112
100
113
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
@@ -114,7 +127,7 @@ def __repr__(self):
114
127
return self .__class__ .__name__ + '()'
115
128
116
129
117
- class ConvertImageDtype ( object ) :
130
+ class ConvertImageDtype :
118
131
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
119
132
120
133
Args:
@@ -139,7 +152,7 @@ def __call__(self, image: torch.Tensor) -> torch.Tensor:
139
152
return F .convert_image_dtype (image , self .dtype )
140
153
141
154
142
- class ToPILImage ( object ) :
155
+ class ToPILImage :
143
156
"""Convert a tensor or an ndarray to PIL Image.
144
157
145
158
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
@@ -178,7 +191,7 @@ def __repr__(self):
178
191
return format_string
179
192
180
193
181
- class Normalize (object ):
194
+ class Normalize (torch . nn . Module ):
182
195
"""Normalize a tensor image with mean and standard deviation.
183
196
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
184
197
channels, this transform will normalize each channel of the input
@@ -196,11 +209,12 @@ class Normalize(object):
196
209
"""
197
210
198
211
def __init__ (self , mean , std , inplace = False ):
212
+ super ().__init__ ()
199
213
self .mean = mean
200
214
self .std = std
201
215
self .inplace = inplace
202
216
203
- def __call__ (self , tensor ) :
217
+ def forward (self , tensor : Tensor ) -> Tensor :
204
218
"""
205
219
Args:
206
220
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
@@ -358,15 +372,16 @@ def __repr__(self):
358
372
format (self .padding , self .fill , self .padding_mode )
359
373
360
374
361
- class Lambda ( object ) :
375
+ class Lambda :
362
376
"""Apply a user-defined lambda as a transform.
363
377
364
378
Args:
365
379
lambd (function): Lambda/function to be used for transform.
366
380
"""
367
381
368
382
def __init__ (self , lambd ):
369
- assert callable (lambd ), repr (type (lambd ).__name__ ) + " object is not callable"
383
+ if not callable (lambd ):
384
+ raise TypeError ("Argument lambd should be callable, got {}" .format (repr (type (lambd ).__name__ )))
370
385
self .lambd = lambd
371
386
372
387
def __call__ (self , img ):
@@ -376,7 +391,7 @@ def __repr__(self):
376
391
return self .__class__ .__name__ + '()'
377
392
378
393
379
- class RandomTransforms ( object ) :
394
+ class RandomTransforms :
380
395
"""Base class for a list of transformations with randomness
381
396
382
397
Args:
@@ -408,7 +423,7 @@ class RandomApply(RandomTransforms):
408
423
"""
409
424
410
425
def __init__ (self , transforms , p = 0.5 ):
411
- super (RandomApply , self ).__init__ (transforms )
426
+ super ().__init__ (transforms )
412
427
self .p = p
413
428
414
429
def __call__ (self , img ):
@@ -897,7 +912,7 @@ def __repr__(self):
897
912
return self .__class__ .__name__ + '(size={0}, vertical_flip={1})' .format (self .size , self .vertical_flip )
898
913
899
914
900
- class LinearTransformation (object ):
915
+ class LinearTransformation (torch . nn . Module ):
901
916
"""Transform a tensor image with a square transformation matrix and a mean_vector computed
902
917
offline.
903
918
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
@@ -916,6 +931,7 @@ class LinearTransformation(object):
916
931
"""
917
932
918
933
def __init__ (self , transformation_matrix , mean_vector ):
934
+ super ().__init__ ()
919
935
if transformation_matrix .size (0 ) != transformation_matrix .size (1 ):
920
936
raise ValueError ("transformation_matrix should be square. Got " +
921
937
"[{} x {}] rectangular matrix." .format (* transformation_matrix .size ()))
@@ -925,24 +941,35 @@ def __init__(self, transformation_matrix, mean_vector):
925
941
" as any one of the dimensions of the transformation_matrix [{}]"
926
942
.format (tuple (transformation_matrix .size ())))
927
943
944
+ if transformation_matrix .device != mean_vector .device :
945
+ raise ValueError ("Input tensors should be on the same device. Got {} and {}"
946
+ .format (transformation_matrix .device , mean_vector .device ))
947
+
928
948
self .transformation_matrix = transformation_matrix
929
949
self .mean_vector = mean_vector
930
950
931
- def __call__ (self , tensor ) :
951
+ def forward (self , tensor : Tensor ) -> Tensor :
932
952
"""
933
953
Args:
934
954
tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
935
955
936
956
Returns:
937
957
Tensor: Transformed image.
938
958
"""
939
- if tensor .size (0 ) * tensor .size (1 ) * tensor .size (2 ) != self .transformation_matrix .size (0 ):
940
- raise ValueError ("tensor and transformation matrix have incompatible shape." +
941
- "[{} x {} x {}] != " .format (* tensor .size ()) +
942
- "{}" .format (self .transformation_matrix .size (0 )))
943
- flat_tensor = tensor .view (1 , - 1 ) - self .mean_vector
959
+ shape = tensor .shape
960
+ n = shape [- 3 ] * shape [- 2 ] * shape [- 1 ]
961
+ if n != self .transformation_matrix .shape [0 ]:
962
+ raise ValueError ("Input tensor and transformation matrix have incompatible shape." +
963
+ "[{} x {} x {}] != " .format (shape [- 3 ], shape [- 2 ], shape [- 1 ]) +
964
+ "{}" .format (self .transformation_matrix .shape [0 ]))
965
+
966
+ if tensor .device .type != self .mean_vector .device .type :
967
+ raise ValueError ("Input tensor should be on the same device as transformation matrix and mean vector. "
968
+ "Got {} vs {}" .format (tensor .device , self .mean_vector .device ))
969
+
970
+ flat_tensor = tensor .view (- 1 , n ) - self .mean_vector
944
971
transformed_tensor = torch .mm (flat_tensor , self .transformation_matrix )
945
- tensor = transformed_tensor .view (tensor . size () )
972
+ tensor = transformed_tensor .view (shape )
946
973
return tensor
947
974
948
975
def __repr__ (self ):
0 commit comments