Skip to content

Commit 98b9aa5

Browse files
authored
Merge pull request #8 from pytorch/lambda
adding lambda transform
2 parents 7ab4204 + 6b175db commit 98b9aa5

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,6 @@ This is popularly used to train the Inception networks
194194
- size: size of the smaller edge
195195
- interpolation: Default: PIL.Image.BILINEAR
196196

197-
198197
### `Pad(padding, fill=0)`
199198
Pads the given image on each side with `padding` number of pixels, and the padding pixels are filled with
200199
pixel value `fill`.
@@ -209,6 +208,14 @@ Given mean: (R, G, B) and std: (R, G, B), will normalize each channel of the tor
209208
- `ToTensor()` - Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
210209
- `ToPILImage()` - Converts a torch.*Tensor of range [0, 1] and shape C x H x W or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C to a PIL.Image of range [0, 255]
211210

211+
## Generic Transforms
212+
### `Lambda(lambd)`
213+
Given a Python lambda (or any callable), applies it to the input `img` and returns the result.
214+
For example:
215+
216+
```python
217+
transforms.Lambda(lambda x: x.add(10))
218+
```
212219

213220
# Utils
214221

test/test_transforms.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,19 @@ def test_pad(self):
100100
transforms.Pad(padding),
101101
transforms.ToTensor(),
102102
])(img)
103-
print(height, width, padding)
104-
print(result.size(1), result.size(2))
105103
assert result.size(1) == height + 2*padding
106104
assert result.size(2) == width + 2*padding
105+
106+
def test_lambda(self):
    # Out-of-place add: the Lambda transform must return the new tensor
    # produced by the wrapped callable.
    shift = transforms.Lambda(lambda t: t.add(10))
    inp = torch.randn(10)
    out = shift(inp)
    assert out.equal(torch.add(inp, 10))

    # In-place add: the transform should hand back the (mutated) input
    # object itself, so result and input compare equal.
    shift_inplace = transforms.Lambda(lambda t: t.add_(10))
    inp = torch.randn(10)
    out = shift_inplace(inp)
    assert out.equal(inp)
107116

108117

109118
if __name__ == '__main__':

torchvision/transforms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from PIL import Image, ImageOps
66
import numpy as np
77
import numbers
8+
import types
89

910
class Compose(object):
1011
""" Composes several transforms together.
@@ -126,6 +127,15 @@ def __init__(self, padding, fill=0):
126127
def __call__(self, img):
127128
return ImageOps.expand(img, border=self.padding, fill=self.fill)
128129

130+
class Lambda(object):
    """Applies a user-defined callable as a transform.

    Args:
        lambd (callable): function/lambda applied to the input image;
            its return value is returned unchanged by ``__call__``.

    Raises:
        TypeError: if ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        # Accept any callable (plain functions, functools.partial, bound
        # methods, ...), not only lambdas, and raise explicitly instead of
        # using `assert`, which is stripped when Python runs with -O.
        if not callable(lambd):
            raise TypeError('lambd should be callable, got {}'.format(type(lambd)))
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)
138+
129139

130140
class RandomCrop(object):
131141
"""Crops the given PIL.Image at a random location to have a region of

0 commit comments

Comments
 (0)