1
1
import torch
2
2
import torchvision .transforms as transforms
3
- import torchvision .datasets as datasets
4
- import numpy as np
5
3
import unittest
6
4
import random
7
5
6
+
8
7
class Tester (unittest .TestCase ):
9
8
def test_crop (self ):
10
9
height = random .randint (10 , 32 ) * 2
11
10
width = random .randint (10 , 32 ) * 2
12
11
oheight = random .randint (5 , (height - 2 ) / 2 ) * 2
13
- owidth = random .randint (5 , (width - 2 ) / 2 ) * 2
14
-
12
+ owidth = random .randint (5 , (width - 2 ) / 2 ) * 2
13
+
15
14
img = torch .ones (3 , height , width )
16
- oh1 = (height - oheight ) / 2
17
- ow1 = (width - owidth ) / 2
18
- imgnarrow = img [:, oh1 :oh1 + oheight , ow1 :ow1 + owidth ]
15
+ oh1 = (height - oheight ) // 2
16
+ ow1 = (width - owidth ) // 2
17
+ imgnarrow = img [:, oh1 :oh1 + oheight , ow1 :ow1 + owidth ]
19
18
imgnarrow .fill_ (0 )
20
19
result = transforms .Compose ([
21
20
transforms .ToPILImage (),
22
21
transforms .CenterCrop ((oheight , owidth )),
23
22
transforms .ToTensor (),
24
23
])(img )
25
24
assert result .sum () == 0 , "height: " + str (height ) + " width: " \
26
- + str ( width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
25
+ + str (width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
27
26
oheight += 1
28
27
owidth += 1
29
28
result = transforms .Compose ([
@@ -33,25 +32,25 @@ def test_crop(self):
33
32
])(img )
34
33
sum1 = result .sum ()
35
34
assert sum1 > 1 , "height: " + str (height ) + " width: " \
36
- + str ( width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
35
+ + str (width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
37
36
oheight += 1
38
- owidth += 1
37
+ owidth += 1
39
38
result = transforms .Compose ([
40
39
transforms .ToPILImage (),
41
40
transforms .CenterCrop ((oheight , owidth )),
42
41
transforms .ToTensor (),
43
42
])(img )
44
43
sum2 = result .sum ()
45
44
assert sum2 > 0 , "height: " + str (height ) + " width: " \
46
- + str ( width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
45
+ + str (width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
47
46
assert sum2 > sum1 , "height: " + str (height ) + " width: " \
48
- + str ( width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
47
+ + str (width ) + " oheight: " + str (oheight ) + " owidth: " + str (owidth )
49
48
50
49
def test_scale (self ):
51
50
height = random .randint (24 , 32 ) * 2
52
51
width = random .randint (24 , 32 ) * 2
53
52
osize = random .randint (5 , 12 ) * 2
54
-
53
+
55
54
img = torch .ones (3 , height , width )
56
55
result = transforms .Compose ([
57
56
transforms .ToPILImage (),
@@ -63,15 +62,15 @@ def test_scale(self):
63
62
# print result.size()
64
63
assert osize in result .size ()
65
64
if height < width :
66
- assert result .size (1 ) <= result .size (2 )
65
+ assert result .size (1 ) <= result .size (2 )
67
66
elif width < height :
68
67
assert result .size (1 ) >= result .size (2 )
69
68
70
69
def test_random_crop (self ):
71
70
height = random .randint (10 , 32 ) * 2
72
71
width = random .randint (10 , 32 ) * 2
73
72
oheight = random .randint (5 , (height - 2 ) / 2 ) * 2
74
- owidth = random .randint (5 , (width - 2 ) / 2 ) * 2
73
+ owidth = random .randint (5 , (width - 2 ) / 2 ) * 2
75
74
img = torch .ones (3 , height , width )
76
75
result = transforms .Compose ([
77
76
transforms .ToPILImage (),
@@ -100,20 +99,20 @@ def test_pad(self):
100
99
transforms .Pad (padding ),
101
100
transforms .ToTensor (),
102
101
])(img )
103
- assert result .size (1 ) == height + 2 * padding
104
- assert result .size (2 ) == width + 2 * padding
102
+ assert result .size (1 ) == height + 2 * padding
103
+ assert result .size (2 ) == width + 2 * padding
105
104
106
105
def test_lambda (self ):
107
106
trans = transforms .Lambda (lambda x : x .add (10 ))
108
107
x = torch .randn (10 )
109
108
y = trans (x )
110
- assert (y .equal (torch .add (x , 10 )))
109
+ assert (y .equal (torch .add (x , 10 )))
111
110
112
111
trans = transforms .Lambda (lambda x : x .add_ (10 ))
113
112
x = torch .randn (10 )
114
113
y = trans (x )
115
- assert (y .equal (x ))
116
-
114
+ assert (y .equal (x ))
115
+
117
116
118
117
if __name__ == '__main__' :
119
118
unittest .main ()
0 commit comments