Skip to content

Commit e56670d

Browse files
authored
Merge pull request #51 from alykhantejani/fix_test_transforms
Use integer division in tests/test_transforms for array slice indices + various PEP-8 fixes
2 parents 6c7733f + 1610268 commit e56670d

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

test/test_transforms.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,28 @@
11
import torch
22
import torchvision.transforms as transforms
3-
import torchvision.datasets as datasets
4-
import numpy as np
53
import unittest
64
import random
75

6+
87
class Tester(unittest.TestCase):
98
def test_crop(self):
109
height = random.randint(10, 32) * 2
1110
width = random.randint(10, 32) * 2
1211
oheight = random.randint(5, (height - 2) / 2) * 2
13-
owidth = random.randint(5, (width - 2) / 2) * 2
14-
12+
owidth = random.randint(5, (width - 2) / 2) * 2
13+
1514
img = torch.ones(3, height, width)
16-
oh1 = (height - oheight) / 2
17-
ow1 = (width - owidth) / 2
18-
imgnarrow = img[:, oh1 :oh1 + oheight, ow1 :ow1 + owidth]
15+
oh1 = (height - oheight) // 2
16+
ow1 = (width - owidth) // 2
17+
imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
1918
imgnarrow.fill_(0)
2019
result = transforms.Compose([
2120
transforms.ToPILImage(),
2221
transforms.CenterCrop((oheight, owidth)),
2322
transforms.ToTensor(),
2423
])(img)
2524
assert result.sum() == 0, "height: " + str(height) + " width: " \
26-
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
25+
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
2726
oheight += 1
2827
owidth += 1
2928
result = transforms.Compose([
@@ -33,25 +32,25 @@ def test_crop(self):
3332
])(img)
3433
sum1 = result.sum()
3534
assert sum1 > 1, "height: " + str(height) + " width: " \
36-
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
35+
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
3736
oheight += 1
38-
owidth += 1
37+
owidth += 1
3938
result = transforms.Compose([
4039
transforms.ToPILImage(),
4140
transforms.CenterCrop((oheight, owidth)),
4241
transforms.ToTensor(),
4342
])(img)
4443
sum2 = result.sum()
4544
assert sum2 > 0, "height: " + str(height) + " width: " \
46-
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
45+
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
4746
assert sum2 > sum1, "height: " + str(height) + " width: " \
48-
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
47+
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
4948

5049
def test_scale(self):
5150
height = random.randint(24, 32) * 2
5251
width = random.randint(24, 32) * 2
5352
osize = random.randint(5, 12) * 2
54-
53+
5554
img = torch.ones(3, height, width)
5655
result = transforms.Compose([
5756
transforms.ToPILImage(),
@@ -63,15 +62,15 @@ def test_scale(self):
6362
# print result.size()
6463
assert osize in result.size()
6564
if height < width:
66-
assert result.size(1) <= result.size(2)
65+
assert result.size(1) <= result.size(2)
6766
elif width < height:
6867
assert result.size(1) >= result.size(2)
6968

7069
def test_random_crop(self):
7170
height = random.randint(10, 32) * 2
7271
width = random.randint(10, 32) * 2
7372
oheight = random.randint(5, (height - 2) / 2) * 2
74-
owidth = random.randint(5, (width - 2) / 2) * 2
73+
owidth = random.randint(5, (width - 2) / 2) * 2
7574
img = torch.ones(3, height, width)
7675
result = transforms.Compose([
7776
transforms.ToPILImage(),
@@ -100,20 +99,20 @@ def test_pad(self):
10099
transforms.Pad(padding),
101100
transforms.ToTensor(),
102101
])(img)
103-
assert result.size(1) == height + 2*padding
104-
assert result.size(2) == width + 2*padding
102+
assert result.size(1) == height + 2 * padding
103+
assert result.size(2) == width + 2 * padding
105104

106105
def test_lambda(self):
107106
trans = transforms.Lambda(lambda x: x.add(10))
108107
x = torch.randn(10)
109108
y = trans(x)
110-
assert(y.equal(torch.add(x, 10)))
109+
assert (y.equal(torch.add(x, 10)))
111110

112111
trans = transforms.Lambda(lambda x: x.add_(10))
113112
x = torch.randn(10)
114113
y = trans(x)
115-
assert(y.equal(x))
116-
114+
assert (y.equal(x))
115+
117116

118117
if __name__ == '__main__':
119118
unittest.main()

0 commit comments

Comments
 (0)