
Commit 4ecf860

Merge branch 'master' of https://github.com/pytorch/vision into inception
2 parents f640996 + 55b5a6b commit 4ecf860

10 files changed: +111 additions, -45 deletions


docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -208,7 +208,7 @@ def patched_make_field(self, types, domain, items, **kw):
     # `kw` catches `env=None` needed for newer sphinx while maintaining
     # backwards compatibility when passed along further down!

-    # type: (List, unicode, Tuple) -> nodes.field
+    # type: (list, unicode, tuple) -> nodes.field
     def handle_item(fieldarg, content):
         par = nodes.paragraph()
         par += addnodes.literal_strong('', fieldarg)  # Patch: this line added

docs/source/datasets.rst

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,11 @@ EMNIST

 .. autoclass:: EMNIST

+FakeData
+~~~~~~~~
+
+.. autoclass:: FakeData
+
 COCO
 ~~~~
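Since FakeData is now listed in the docs, here is a minimal usage sketch (not part of this diff; the constructor arguments size, image_size and num_classes are assumed from the class defaults):

    import torchvision.transforms as transforms
    from torchvision.datasets import FakeData

    # Synthetic dataset: 100 random 3x224x224 images with labels in [0, 10).
    dataset = FakeData(size=100, image_size=(3, 224, 224), num_classes=10,
                       transform=transforms.ToTensor())
    img, target = dataset[0]
    print(img.shape, target)  # torch.Size([3, 224, 224]) and a random class label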

test/test_transforms.py

Lines changed: 24 additions & 4 deletions
@@ -1,3 +1,4 @@
+from __future__ import division
 import torch
 import torchvision.transforms as transforms
 import torchvision.transforms.functional as F
@@ -130,6 +131,25 @@ def test_ten_crop(self):
         assert len(results) == 10
         assert expected_output == results

+    def test_randomresized_params(self):
+        height = random.randint(24, 32) * 2
+        width = random.randint(24, 32) * 2
+        img = torch.ones(3, height, width)
+        to_pil_image = transforms.ToPILImage()
+        img = to_pil_image(img)
+        size = 100
+        epsilon = 0.05
+        for i in range(10):
+            scale_min = round(random.random(), 2)
+            scale_range = (scale_min, scale_min + round(random.random(), 2))
+            aspect_min = round(random.random(), 2)
+            aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
+            randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range)
+            _, _, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
+            aspect_ratio_obtained = w / h
+            assert (min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon or
+                    aspect_ratio_obtained == 1.0)
+
     def test_resize(self):
         height = random.randint(24, 32) * 2
         width = random.randint(24, 32) * 2
@@ -990,10 +1010,10 @@ def test_rotate(self):
         assert np.all(np.array(result_a) == np.array(result_b))

     def test_affine(self):
-        input_img = np.zeros((200, 200, 3), dtype=np.uint8)
+        input_img = np.zeros((40, 40, 3), dtype=np.uint8)
         pts = []
-        cnt = [100, 100]
-        for pt in [(80, 80), (100, 80), (100, 100)]:
+        cnt = [20, 20]
+        for pt in [(16, 16), (20, 16), (20, 20)]:
             for i in range(-5, 5):
                 for j in range(-5, 5):
                     input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
@@ -1028,7 +1048,7 @@ def _test_transformation(a, t, s, sh):
                                             translate=t, scale=s, shear=sh))
             assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10
             # 2) Perform inverse mapping:
-            true_result = np.zeros((200, 200, 3), dtype=np.uint8)
+            true_result = np.zeros((40, 40, 3), dtype=np.uint8)
             inv_true_matrix = np.linalg.inv(true_matrix)
             for y in range(true_result.shape[0]):
                 for x in range(true_result.shape[1]):

torchvision/datasets/mnist.py

Lines changed: 24 additions & 2 deletions
@@ -1,4 +1,5 @@
 from __future__ import print_function
+import warnings
 import torch.utils.data as data
 from PIL import Image
 import os
@@ -37,6 +38,26 @@ class MNIST(data.Dataset):
     classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
                '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

+    @property
+    def train_labels(self):
+        warnings.warn("train_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def test_labels(self):
+        warnings.warn("test_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def train_data(self):
+        warnings.warn("train_data has been renamed data")
+        return self.data
+
+    @property
+    def test_data(self):
+        warnings.warn("test_data has been renamed data")
+        return self.data
+
     def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
         self.root = os.path.expanduser(root)
         self.transform = transform
@@ -205,7 +226,7 @@ class KMNIST(MNIST):


 class EMNIST(MNIST):
-    """`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.
+    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

     Args:
         root (string): Root directory of dataset where ``processed/training.pt``
@@ -223,7 +244,8 @@ class EMNIST(MNIST):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
     """
-    url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
+    # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
+    url = 'https://cloudstor.aarnet.edu.au/plus/index.php/s/54h3OuGJhFLwAlQ/download'
     splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

     def __init__(self, root, split, **kwargs):
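A quick sketch (not part of the commit) of what the new alias properties do: reading the old attribute names emits a warning and falls through to the renamed attributes. The root path below is arbitrary:

    import warnings
    from torchvision.datasets import MNIST

    dataset = MNIST(root='./data', train=True, download=True)  # './data' is an arbitrary local path

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        labels = dataset.train_labels    # deprecated alias
    assert labels is dataset.targets     # same object as under the new name
    print(caught[0].message)             # train_labels has been renamed targets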

torchvision/models/alexnet.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,7 @@ def __init__(self, num_classes=1000):
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=3, stride=2),
         )
+        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
         self.classifier = nn.Sequential(
             nn.Dropout(),
             nn.Linear(256 * 6 * 6, 4096),
@@ -41,6 +42,7 @@ def __init__(self, num_classes=1000):

     def forward(self, x):
         x = self.features(x)
+        x = self.avgpool(x)
         x = x.view(x.size(0), 256 * 6 * 6)
         x = self.classifier(x)
         return x
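For context, a small sketch (random weights, arbitrarily chosen input size) of what the inserted AdaptiveAvgPool2d buys: the 256 * 6 * 6 flatten no longer requires the feature map to land at exactly 6 x 6, so inputs larger than 224 x 224 pass through:

    import torch
    from torchvision.models import alexnet

    model = alexnet(pretrained=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))  # features yield 7x7 here; avgpool maps it to 6x6
    print(out.shape)  # torch.Size([1, 1000])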

torchvision/models/inception.py

Lines changed: 39 additions & 31 deletions
@@ -17,6 +17,10 @@ def inception_v3(pretrained=False, **kwargs):
     r"""Inception v3 model architecture from
     `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

+    .. note::
+        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
+        N x 3 x 299 x 299, so ensure your images are sized accordingly.
+
     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
@@ -74,54 +78,55 @@ def forward(self, x):
         x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
         x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
         x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
-        # 299 x 299 x 3
+        # N x 3 x 299 x 299
         x = self.Conv2d_1a_3x3(x)
-        # 149 x 149 x 32
+        # N x 32 x 149 x 149
         x = self.Conv2d_2a_3x3(x)
-        # 147 x 147 x 32
+        # N x 32 x 147 x 147
         x = self.Conv2d_2b_3x3(x)
-        # 147 x 147 x 64
+        # N x 64 x 147 x 147
         x = F.max_pool2d(x, kernel_size=3, stride=2)
-        # 73 x 73 x 64
+        # N x 64 x 73 x 73
         x = self.Conv2d_3b_1x1(x)
-        # 73 x 73 x 80
+        # N x 80 x 73 x 73
         x = self.Conv2d_4a_3x3(x)
-        # 71 x 71 x 192
+        # N x 192 x 71 x 71
         x = F.max_pool2d(x, kernel_size=3, stride=2)
-        # 35 x 35 x 192
+        # N x 192 x 35 x 35
         x = self.Mixed_5b(x)
-        # 35 x 35 x 256
+        # N x 256 x 35 x 35
         x = self.Mixed_5c(x)
-        # 35 x 35 x 288
+        # N x 288 x 35 x 35
         x = self.Mixed_5d(x)
-        # 35 x 35 x 288
+        # N x 288 x 35 x 35
         x = self.Mixed_6a(x)
-        # 17 x 17 x 768
+        # N x 768 x 17 x 17
         x = self.Mixed_6b(x)
-        # 17 x 17 x 768
+        # N x 768 x 17 x 17
         x = self.Mixed_6c(x)
-        # 17 x 17 x 768
+        # N x 768 x 17 x 17
         x = self.Mixed_6d(x)
-        # 17 x 17 x 768
+        # N x 768 x 17 x 17
         x = self.Mixed_6e(x)
-        # 17 x 17 x 768
+        # N x 768 x 17 x 17
         if self.training and self.aux_logits:
             aux = self.AuxLogits(x)
-        # 17 x 17 x 768
+        # N x 768 x 17 x 17
         x = self.Mixed_7a(x)
-        # 8 x 8 x 1280
+        # N x 1280 x 8 x 8
         x = self.Mixed_7b(x)
-        # 8 x 8 x 2048
+        # N x 2048 x 8 x 8
         x = self.Mixed_7c(x)
-        # 8 x 8 x 2048
-        x = F.avg_pool2d(x, kernel_size=8)
-        # 1 x 1 x 2048
+        # N x 2048 x 8 x 8
+        # Adaptive average pooling
+        x = F.adaptive_avg_pool2d(x, (1, 1))
+        # N x 2048 x 1 x 1
         x = F.dropout(x, training=self.training)
-        # 1 x 1 x 2048
+        # N x 2048 x 1 x 1
         x = x.view(x.size(0), -1)
-        # 2048
+        # N x 2048
         x = self.fc(x)
-        # 1000 (num_classes)
+        # N x 1000 (num_classes)
         if self.training and self.aux_logits:
             return x, aux
         return x
@@ -300,17 +305,20 @@ def __init__(self, in_channels, num_classes):
         self.fc.stddev = 0.001

     def forward(self, x):
-        # 17 x 17 x 768
+        # N x 768 x 17 x 17
         x = F.avg_pool2d(x, kernel_size=5, stride=3)
-        # 5 x 5 x 768
+        # N x 768 x 5 x 5
         x = self.conv0(x)
-        # 5 x 5 x 128
+        # N x 128 x 5 x 5
         x = self.conv1(x)
-        # 1 x 1 x 768
+        # N x 768 x 1 x 1
+        # Adaptive average pooling
+        x = F.adaptive_avg_pool2d(x, (1, 1))
+        # N x 768 x 1 x 1
         x = x.view(x.size(0), -1)
-        # 768
+        # N x 768
         x = self.fc(x)
-        # 1000
+        # N x 1000
         return x
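To make the reordered shape comments concrete, a minimal usage sketch (random weights; aux_logits left at its default of True) showing the N-first convention from the new note and the extra auxiliary output in training mode:

    import torch
    from torchvision.models import inception_v3

    model = inception_v3(pretrained=False)   # aux_logits=True by default
    x = torch.randn(2, 3, 299, 299)          # N x 3 x 299 x 299, as the new note requires

    model.train()
    out, aux = model(x)                      # training mode returns (N x 1000, N x 1000)

    model.eval()
    with torch.no_grad():
        out = model(x)                       # eval mode returns only the main head
    print(out.shape)                         # torch.Size([2, 1000])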

torchvision/models/vgg.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@ class VGG(nn.Module):
     def __init__(self, features, num_classes=1000, init_weights=True):
         super(VGG, self).__init__()
         self.features = features
+        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
         self.classifier = nn.Sequential(
             nn.Linear(512 * 7 * 7, 4096),
             nn.ReLU(True),
@@ -39,6 +40,7 @@ def __init__(self, features, num_classes=1000, init_weights=True):

     def forward(self, x):
         x = self.features(x)
+        x = self.avgpool(x)
         x = x.view(x.size(0), -1)
         x = self.classifier(x)
         return x

torchvision/transforms/functional.py

Lines changed: 3 additions & 3 deletions
@@ -117,7 +117,7 @@ def to_pil_image(pic, mode=None):

     elif pic.ndimension() == 2:
         # if 2D image, add channel dimension (CHW)
-        pic.unsqueeze_(0)
+        pic = pic.unsqueeze(0)

     elif isinstance(pic, np.ndarray):
         if pic.ndim not in {2, 3}:
@@ -376,8 +376,8 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):

     Args:
         img (PIL Image): Image to be cropped.
-        i: Upper pixel coordinate.
-        j: Left pixel coordinate.
+        i: i in (i,j) i.e coordinates of the upper left corner
+        j: j in (i,j) i.e coordinates of the upper left corner
         h: Height of the cropped image.
         w: Width of the cropped image.
         size (sequence or int): Desired output size. Same semantics as ``resize``.
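A small sketch (hypothetical tensor) of the behavioral difference in to_pil_image: the caller's 2-D tensor is no longer modified in place when the channel dimension is added:

    import torch
    import torchvision.transforms.functional as F

    pic = torch.rand(28, 28)   # 2-D grayscale tensor, values in [0, 1]
    img = F.to_pil_image(pic)  # previously this called pic.unsqueeze_(0) and mutated pic
    assert pic.dim() == 2      # with this change the input keeps its original shape
    print(img.size)            # (28, 28)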

torchvision/transforms/transforms.py

Lines changed: 8 additions & 2 deletions
@@ -543,7 +543,13 @@ class RandomResizedCrop(object):
     """

     def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
-        self.size = (size, size)
+        if isinstance(size, tuple):
+            self.size = size
+        else:
+            self.size = (size, size)
+        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+            warnings.warn("range should be of kind (min, max)")
+
         self.interpolation = interpolation
         self.scale = scale
         self.ratio = ratio
@@ -570,7 +576,7 @@ def get_params(img, scale, ratio):
             w = int(round(math.sqrt(target_area * aspect_ratio)))
             h = int(round(math.sqrt(target_area / aspect_ratio)))

-            if random.random() < 0.5:
+            if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio):
                 w, h = h, w

             if w <= img.size[0] and h <= img.size[1]:
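A short sketch (blank PIL image as a stand-in) of the two user-visible changes to RandomResizedCrop: a tuple size is now accepted as-is, and a reversed (max, min) range triggers a warning:

    import warnings
    from PIL import Image
    import torchvision.transforms as transforms

    img = Image.new('RGB', (500, 400))

    crop = transforms.RandomResizedCrop((224, 224))  # tuple sizes no longer get wrapped again as (size, size)
    print(crop(img).size)                            # (224, 224)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        transforms.RandomResizedCrop(224, scale=(1.0, 0.08))  # reversed range
    print(caught[0].message)                         # range should be of kind (min, max)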

torchvision/utils.py

Lines changed: 3 additions & 2 deletions
@@ -74,7 +74,7 @@ def norm_range(t, range):
     xmaps = min(nrow, nmaps)
     ymaps = int(math.ceil(float(nmaps) / xmaps))
     height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
-    grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_(pad_value)
+    grid = tensor.new_full((3, height * ymaps + padding, width * xmaps + padding), pad_value)
     k = 0
     for y in irange(ymaps):
         for x in irange(xmaps):
@@ -99,6 +99,7 @@ def save_image(tensor, filename, nrow=8, padding=2,
     from PIL import Image
     grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                      normalize=normalize, range=range, scale_each=scale_each)
-    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
+    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
+    ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
     im = Image.fromarray(ndarr)
     im.save(filename)
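For reference, a minimal sketch (arbitrary output path) of the affected call; with this change the grid border is filled via new_full, and pixel values are rounded rather than truncated when cast to uint8:

    import torch
    from torchvision.utils import save_image

    batch = torch.rand(16, 3, 32, 32)  # values in [0, 1]
    save_image(batch, 'grid.png', nrow=4, padding=2, pad_value=1.0)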
