Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import unittest
import random


class Tester(unittest.TestCase):
def test_crop(self):
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2

img = torch.ones(3, height, width)
oh1 = (height - oheight) / 2
ow1 = (width - owidth) / 2
imgnarrow = img[:, oh1 :oh1 + oheight, ow1 :ow1 + owidth]
oh1 = (height - oheight) // 2
ow1 = (width - owidth) // 2
imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
imgnarrow.fill_(0)
result = transforms.Compose([
transforms.ToPILImage(),
transforms.CenterCrop((oheight, owidth)),
transforms.ToTensor(),
])(img)
assert result.sum() == 0, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1
owidth += 1
result = transforms.Compose([
Expand All @@ -33,25 +32,25 @@ def test_crop(self):
])(img)
sum1 = result.sum()
assert sum1 > 1, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1
owidth += 1
owidth += 1
result = transforms.Compose([
transforms.ToPILImage(),
transforms.CenterCrop((oheight, owidth)),
transforms.ToTensor(),
])(img)
sum2 = result.sum()
assert sum2 > 0, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
assert sum2 > sum1, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)

def test_scale(self):
height = random.randint(24, 32) * 2
width = random.randint(24, 32) * 2
osize = random.randint(5, 12) * 2

img = torch.ones(3, height, width)
result = transforms.Compose([
transforms.ToPILImage(),
Expand All @@ -63,15 +62,15 @@ def test_scale(self):
# print result.size()
assert osize in result.size()
if height < width:
assert result.size(1) <= result.size(2)
assert result.size(1) <= result.size(2)
elif width < height:
assert result.size(1) >= result.size(2)

def test_random_crop(self):
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
img = torch.ones(3, height, width)
result = transforms.Compose([
transforms.ToPILImage(),
Expand Down Expand Up @@ -100,20 +99,20 @@ def test_pad(self):
transforms.Pad(padding),
transforms.ToTensor(),
])(img)
assert result.size(1) == height + 2*padding
assert result.size(2) == width + 2*padding
assert result.size(1) == height + 2 * padding
assert result.size(2) == width + 2 * padding

def test_lambda(self):
trans = transforms.Lambda(lambda x: x.add(10))
x = torch.randn(10)
y = trans(x)
assert(y.equal(torch.add(x, 10)))
assert (y.equal(torch.add(x, 10)))

trans = transforms.Lambda(lambda x: x.add_(10))
x = torch.randn(10)
y = trans(x)
assert(y.equal(x))
assert (y.equal(x))


if __name__ == '__main__':
unittest.main()