Skip to content

Commit f6d49d8

Browse files
yanghuanflcfmassa
authored andcommitted
Add support for Transforms.Scale([w, h]) with specific width and height (#133)
* Fixed border missing on bottom and right side using make_grid * Add support for Transforms.Scale([h, w]) with specific height and width if self.size is a int then scale the image with the shorter side, otherwise if self.size is a list then scale the image to self.size directly * Add assert of size and doc for README Add assert of size and doc for README * Fix linter problem Fix linter problem * Add test for Scale Add test for Scale * Add both tuple and list support for Scale.size Add both tuple and list support for Scale.size * Add order of Scale.size in document and test case for list type of Scale.size Add order of Scale.size in document and test case for list type of Scale.size
1 parent de5dcb9 commit f6d49d8

File tree

3 files changed

+40
-14
lines changed

3 files changed

+40
-14
lines changed

README.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,11 @@ Transforms on PIL.Image
307307
``Scale(size, interpolation=Image.BILINEAR)``
308308
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
309309

310-
Rescales the input PIL.Image to the given 'size'. 'size' will be the
311-
size of the smaller edge.
310+
Rescales the input PIL.Image to the given 'size'.
312311

312+
If 'size' is a 2-element tuple or list in the order of (width, height), it will be the exactly size to scale.
313+
314+
If 'size' is a number, it will indicate the size of the smaller edge.
313315
For example, if height > width, then image will be rescaled to (size \*
314316
height / width, size) - size: size of the smaller edge - interpolation:
315317
Default: PIL.Image.BILINEAR

test/test_transforms.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,24 @@ def test_scale(self):
6868
elif width < height:
6969
assert result.size(1) >= result.size(2)
7070

71+
oheight = random.randint(5, 12) * 2
72+
owidth = random.randint(5, 12) * 2
73+
result = transforms.Compose([
74+
transforms.ToPILImage(),
75+
transforms.Scale((owidth, oheight)),
76+
transforms.ToTensor(),
77+
])(img)
78+
assert result.size(1) == oheight
79+
assert result.size(2) == owidth
80+
81+
result = transforms.Compose([
82+
transforms.ToPILImage(),
83+
transforms.Scale([owidth, oheight]),
84+
transforms.ToTensor(),
85+
])(img)
86+
assert result.size(1) == oheight
87+
assert result.size(2) == owidth
88+
7189
def test_random_crop(self):
7290
height = random.randint(10, 32) * 2
7391
width = random.randint(10, 32) * 2

torchvision/transforms.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import numbers
88
import types
9+
import collections
910

1011

1112
class Compose(object):
@@ -115,29 +116,34 @@ def __call__(self, tensor):
115116

116117
class Scale(object):
117118
"""Rescales the input PIL.Image to the given 'size'.
118-
'size' will be the size of the smaller edge.
119+
If 'size' is a 2-element tuple or list in the order of (width, height), it will be the exactly size to scale.
120+
If 'size' is a number, it will indicate the size of the smaller edge.
119121
For example, if height > width, then image will be
120122
rescaled to (size * height / width, size)
121-
size: size of the smaller edge
123+
size: size of the exactly size or the smaller edge
122124
interpolation: Default: PIL.Image.BILINEAR
123125
"""
124126

125127
def __init__(self, size, interpolation=Image.BILINEAR):
128+
assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
126129
self.size = size
127130
self.interpolation = interpolation
128131

129132
def __call__(self, img):
130-
w, h = img.size
131-
if (w <= h and w == self.size) or (h <= w and h == self.size):
132-
return img
133-
if w < h:
134-
ow = self.size
135-
oh = int(self.size * h / w)
136-
return img.resize((ow, oh), self.interpolation)
133+
if isinstance(self.size, int):
134+
w, h = img.size
135+
if (w <= h and w == self.size) or (h <= w and h == self.size):
136+
return img
137+
if w < h:
138+
ow = self.size
139+
oh = int(self.size * h / w)
140+
return img.resize((ow, oh), self.interpolation)
141+
else:
142+
oh = self.size
143+
ow = int(self.size * w / h)
144+
return img.resize((ow, oh), self.interpolation)
137145
else:
138-
oh = self.size
139-
ow = int(self.size * w / h)
140-
return img.resize((ow, oh), self.interpolation)
146+
return img.resize(self.size, self.interpolation)
141147

142148

143149
class CenterCrop(object):

0 commit comments

Comments
 (0)