Skip to content

Commit 4390b55

Browse files
committed
Make get_params static method
1 parent 8b18f52 commit 4390b55

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

torchvision/transforms.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def to_tensor(pic):
6767
return img
6868

6969

70-
def to_pilimage(pic):
70+
def to_pil_image(pic):
7171
if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
7272
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
7373

@@ -219,7 +219,7 @@ def __call__(self, pic):
219219
PIL.Image: Image converted to PIL.Image.
220220
221221
"""
222-
return to_pilimage(pic)
222+
return to_pil_image(pic)
223223

224224

225225
class Normalize(object):
@@ -294,9 +294,10 @@ def __init__(self, size):
294294
else:
295295
self.size = size
296296

297-
def get_params(self, img):
297+
@staticmethod
298+
def get_params(img, output_size):
298299
w, h = img.size
299-
th, tw = self.size
300+
th, tw = output_size
300301
x1 = int(round((w - tw) / 2.))
301302
y1 = int(round((h - th) / 2.))
302303
return x1, y1, tw, th
@@ -309,7 +310,7 @@ def __call__(self, img):
309310
Returns:
310311
PIL.Image: Cropped image.
311312
"""
312-
x1, y1, tw, th = self.get_params(img)
313+
x1, y1, tw, th = self.get_params(img, self.size)
313314
return crop(img, x1, y1, tw, th)
314315

315316

@@ -382,9 +383,10 @@ def __init__(self, size, padding=0):
382383
self.size = size
383384
self.padding = padding
384385

385-
def get_params(self, img):
386+
@staticmethod
387+
def get_params(img, output_size):
386388
w, h = img.size
387-
th, tw = self.size
389+
th, tw = output_size
388390
if w == tw and h == th:
389391
return img
390392

@@ -403,7 +405,7 @@ def __call__(self, img):
403405
if self.padding > 0:
404406
img = pad(img, self.padding)
405407

406-
x1, y1, tw, th = self.get_params(img)
408+
x1, y1, tw, th = self.get_params(img, self.size)
407409

408410
return crop(img, x1, y1, tw, th)
409411

@@ -441,7 +443,8 @@ def __init__(self, size, interpolation=Image.BILINEAR):
441443
self.size = size
442444
self.interpolation = interpolation
443445

444-
def get_params(self, img):
446+
@staticmethod
447+
def get_params(img):
445448
for attempt in range(10):
446449
area = img.size[0] * img.size[1]
447450
target_area = random.uniform(0.08, 1.0) * area

0 commit comments

Comments
 (0)