Skip to content

Commit f4b17fd

Browse files
committed
add cvtransforms
1 parent 97871cd commit f4b17fd

File tree

2 files changed

+221
-0
lines changed

2 files changed

+221
-0
lines changed

test/test_cvtransforms.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import cvtransforms
2+
import transforms
3+
import unittest
4+
import numpy as np
5+
import torch
6+
import cv2
7+
8+
class TestOpenCVTransforms(unittest.TestCase):
    """Unit tests for the OpenCV/numpy based transforms in ``cvtransforms``."""

    def testScale(self):
        """Scale sets the smaller edge to ``size`` and keeps the aspect ratio."""
        size = 43
        w, h = 68, 54
        img = np.random.randn(h, w, 3)
        tr = cvtransforms.Scale(size)
        res = tr(img)
        # Height is the smaller edge here, so it becomes ``size``; the width
        # is scaled by the same factor and truncated to an int, exactly as
        # Scale computes it.
        self.assertEqual(res.shape[0], size)
        self.assertEqual(res.shape[1], int(float(size) * w / h))

    def testCenterCrop(self):
        """CenterCrop with an integer size yields a square (size, size) crop."""
        size = 43
        w, h = 68, 54
        img = np.random.randn(h, w, 3)
        tr = cvtransforms.CenterCrop(size)
        res = tr(img)
        self.assertEqual(res.shape[0], size)
        self.assertEqual(res.shape[1], size)

    def testNormalize(self):
        """The numpy Normalize agrees with the torch-tensor Normalize."""
        meanstd = dict(mean=[1, 2, 3], std=[1, 1, 1])
        normalize = transforms.Normalize(**meanstd)
        cvnormalize = cvtransforms.Normalize(**meanstd)

        w, h = 68, 54
        img = np.random.randn(h, w, 3)
        # Channel i is filled with its own mean (i + 1), so the expected
        # normalized result is all zeros for both implementations.
        for i in range(3):
            img[:, :, i] = i + 1
        # The torch transform is fed a C x H x W tensor, so permute in and
        # back out to compare against the H x W x C numpy result.
        res_th = normalize(
            torch.from_numpy(img).clone().permute(2, 0, 1)
        ).permute(1, 2, 0).numpy()
        res_np = cvnormalize(img)
        self.assertEqual(np.abs(res_np - res_th).sum(), 0)

    def testFlip(self):
        """A performed horizontal flip swaps the left and right halves."""
        w, h = 12, 10
        img = np.random.randn(h, w, 1)
        img[:, :6, :] = 0
        img[:, 6:, :] = 1

        # The transform returns the very same object when it decides not to
        # flip, so retry until a new (flipped) array comes back.
        flip = img
        while flip is img:
            flip = cvtransforms.RandomHorizontalFlip()(img)
        self.assertEqual(flip[:, :6, :].mean(), 1)
        self.assertEqual(flip[:, 6:, :].mean(), 0)

    def testPadding(self):
        """Reflect padding by 2 widens each constant half from 6 to 8 columns."""
        w, h = 12, 10
        img = np.random.randn(h, w, 1)
        img[:, :6, :] = 0
        img[:, 6:, :] = 1

        padded = cvtransforms.Pad(2, cv2.BORDER_REFLECT)(img)
        self.assertEqual(padded[:, :8, :].mean(), 0)
        self.assertEqual(padded[:, 8:, :].mean(), 1)
61+
62+
63+
# Run the whole test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()

torchvision/cvtransforms.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from __future__ import division
2+
import math
3+
import random
4+
import numpy as np
5+
import numbers
6+
import cv2
7+
8+
9+
class Normalize(object):
    """Normalizes each channel of an np.ndarray.

    Given mean: (R, G, B) and std: (R, G, B), computes
    channel = (channel - mean) / std
    The per-channel values are combined with the array through numpy
    broadcasting, so the channel axis is expected to be the last one
    (H x W x C), as in the accompanying tests.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        centered = tensor - self.mean
        return centered / self.std
21+
22+
23+
class Scale(object):
    """Rescales the input np.ndarray image to the given 'size'.

    'size' will be the size of the smaller edge and the aspect ratio is
    kept. For example, if height > width, then image will be
    rescaled to (size * height / width, size) as (height, width).
    size: size of the smaller edge
    interpolation: OpenCV interpolation flag. Default: cv2.INTER_CUBIC
    """
    def __init__(self, size, interpolation=cv2.INTER_CUBIC):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        w, h = img.shape[1], img.shape[0]
        # Nothing to do when the smaller edge already has the target size;
        # the input object itself is returned in that case.
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img
        if w < h:
            ow = self.size
            oh = int(float(self.size) * h / w)
        else:
            oh = self.size
            ow = int(float(self.size) * w / h)
        # cv2.resize takes dsize as (width, height).
        res = cv2.resize(img, dsize=(ow, oh),
                         interpolation=self.interpolation)
        # OpenCV returns a 2-D array for an H x W x 1 input; put the channel
        # axis back so the output rank matches the input (Pad in this file
        # re-adds the axis for the same reason).
        if np.ndim(img) == 3 and np.ndim(res) == 2:
            res = res[:, :, np.newaxis]
        return res
47+
48+
49+
class CenterCrop(object):
    """Crops the given np.ndarray at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape
    (size, size)

    Both H x W and H x W x C arrays are accepted; the result is a slice
    of the input array. Raises ValueError if the requested crop does not
    fit inside the image.
    """
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        w, h = img.shape[1], img.shape[0]
        th, tw = self.size
        # A crop larger than the image would give negative offsets, which
        # numpy interprets as "from the end" and silently returns a wrong,
        # undersized region. Fail loudly instead.
        if tw > w or th > h:
            raise ValueError(
                "crop size (%d, %d) is larger than image size (%d, %d)"
                % (th, tw, h, w))
        x1 = int(round((w - tw) / 2.))
        y1 = int(round((h - th) / 2.))
        # Only the two spatial axes are indexed so that 2-D (grayscale)
        # arrays work as well as H x W x C ones.
        return img[y1:y1 + th, x1:x1 + tw]
67+
68+
69+
class Pad(object):
    """Pads the given np.ndarray by the same amount on all four sides.

    padding: number of pixels added to the top, bottom, left and right
    borderType: border mode passed to cv2.copyMakeBorder.
        Default: cv2.BORDER_CONSTANT
    borderValue: passed to cv2.copyMakeBorder as 'value'. Default: 0
    """

    def __init__(self, padding, borderType=cv2.BORDER_CONSTANT, borderValue=0):
        assert isinstance(padding, numbers.Number)
        self.padding = padding
        self.borderType = borderType
        self.borderValue = borderValue

    def __call__(self, img):
        amount = self.padding
        # Zero padding is a no-op: hand back the input object untouched.
        if amount == 0:
            return img
        padded = cv2.copyMakeBorder(img, amount, amount, amount, amount,
                                    borderType=self.borderType,
                                    value=self.borderValue)
        if np.ndim(padded) != 2:
            return padded
        # A 2-D result gets a trailing channel axis so the output is
        # always H x W x C.
        return padded[:, :, np.newaxis]
86+
87+
88+
class RandomCrop(object):
    """Crops the given np.ndarray at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape
    (size, size)

    Both H x W and H x W x C arrays are accepted. When the image already
    has the target size the input object itself is returned. Raises
    ValueError if the requested crop does not fit inside the image.
    """
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        w, h = img.shape[1], img.shape[0]
        th, tw = self.size
        if w == tw and h == th:
            return img
        # random.randint would also raise ValueError for an empty range,
        # but with a message that says nothing about image sizes.
        if tw > w or th > h:
            raise ValueError(
                "crop size (%d, %d) is larger than image size (%d, %d)"
                % (th, tw, h, w))

        # randint is inclusive on both ends, so every valid top-left corner
        # (including the last one) can be drawn.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        # Only the two spatial axes are indexed so that 2-D (grayscale)
        # arrays work as well as H x W x C ones.
        return img[y1:y1 + th, x1:x1 + tw]
109+
110+
111+
class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given np.ndarray.

    p: probability of performing the flip. Default: 0.5
    When no flip is performed the input object itself is returned.
    """
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            # flipCode 1 mirrors around the vertical axis (left <-> right).
            # The reshape forces the result back to the input's exact shape.
            return cv2.flip(img, 1).reshape(img.shape)
        return img
118+
119+
120+
class RandomSizedCrop(object):
    """Random crop the given np.ndarray to a random size of (0.08 to 1.0) of the original area
    and a random aspect ratio between 3/4 and 4/3, then resize the crop
    to (size, size).
    This is popularly used to train the Inception networks
    size: edge length of the square output
    interpolation: Default: cv2.INTER_CUBIC

    Up to 10 random crops are tried; if none fits inside the image the
    result falls back to Scale followed by CenterCrop.
    """
    def __init__(self, size, interpolation=cv2.INTER_CUBIC):
        self.size = size
        self.interpolation = interpolation

    @staticmethod
    def _match_ndim(result, reference):
        """Re-add the channel axis that OpenCV drops for H x W x 1 inputs."""
        if np.ndim(reference) == 3 and np.ndim(result) == 2:
            return result[:, :, np.newaxis]
        return result

    def __call__(self, img):
        img_h, img_w = img.shape[0], img.shape[1]
        # The source area does not change between attempts.
        area = img_h * img_w

        for _ in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4., 4. / 3.)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Randomly choose between landscape and portrait orientation.
            if random.random() < 0.5:
                w, h = h, w

            # On tiny images the rounded extent can be 0, which would give
            # an empty crop that cv2.resize cannot handle; treat that as a
            # failed attempt and draw again.
            if 0 < w <= img_w and 0 < h <= img_h:
                x1 = random.randint(0, img_w - w)
                y1 = random.randint(0, img_h - h)

                patch = img[y1:y1 + h, x1:x1 + w]
                resized = cv2.resize(patch, (self.size, self.size),
                                     interpolation=self.interpolation)
                return self._match_ndim(resized, img)

        # Fallback
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(self._match_ndim(scale(img), img))
157+

0 commit comments

Comments
 (0)