diff --git a/test/test_transforms.py b/test/test_transforms.py
index 5d0275f946f..84d14d0e95f 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1,6 +1,7 @@
 import torch
 import torchvision.transforms as transforms
 import unittest
+from parameterized import parameterized
 import random
 import numpy as np
 from PIL import Image
@@ -171,17 +172,35 @@ def test_lambda(self):
         y = trans(x)
         assert (y.equal(x))
 
-    def test_to_tensor(self):
-        channels = 3
-        height, width = 4, 4
-        trans = transforms.ToTensor()
-        input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
+    @parameterized.expand([
+        ('3channel', 3, 4, 4),
+        ('1channel', 1, 4, 4),
+    ])
+    def test_pil_to_tensor(self, _, channels, height, width):
+        input_data = torch.ByteTensor(channels, height, width)
+        input_data = input_data.random_(0, 255).float().div_(255)
         img = transforms.ToPILImage()(input_data)
-        output = trans(img)
+        output = transforms.ToTensor()(img)
         assert np.allclose(input_data.numpy(), output.numpy())
 
-        ndarray = np.random.randint(low=0, high=255, size=(height, width, channels))
-        output = trans(ndarray)
+    @parameterized.expand([
+        ('smoke', 4, 4),
+    ])
+    def test_ndarray_to_tensor_2dim(self, _, height, width):
+        ndarray_size = (height, width)
+        ndarray = np.random.randint(low=0, high=255, size=ndarray_size)
+        output = transforms.ToTensor()(ndarray)
+        expected_output = ndarray[..., np.newaxis].transpose((2, 0, 1)) / 255.0
+        assert np.allclose(output.numpy(), expected_output)
+
+    @parameterized.expand([
+        ('1channel', 1, 4, 4),
+        ('3channel', 3, 4, 4),
+    ])
+    def test_ndarray_to_tensor_3dim(self, _, channels, height, width):
+        ndarray_size = (height, width, channels)
+        ndarray = np.random.randint(low=0, high=255, size=ndarray_size)
+        output = transforms.ToTensor()(ndarray)
         expected_output = ndarray.transpose((2, 0, 1)) / 255.0
         assert np.allclose(output.numpy(), expected_output)
 
diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index da58aa12b9a..0e091885824 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -52,9 +52,16 @@ def __call__(self, pic):
         """
         if isinstance(pic, np.ndarray):
             # handle numpy array
-            img = torch.from_numpy(pic.transpose((2, 0, 1)))
+            if pic.ndim == 2:
+                # single-channel image: add a leading channel axis (HW -> 1HW)
+                pic = pic[np.newaxis]
+            elif pic.ndim == 3:
+                # multi-channel image: reorder HWC -> CHW
+                pic = pic.transpose((2, 0, 1))
+            else:
+                raise ValueError('only 2D and 3D images accepted, got {}D image'.format(pic.ndim))
             # backward compatibility
-            return img.float().div(255)
+            return torch.from_numpy(pic).float().div(255.)
 
         if accimage is not None and isinstance(pic, accimage.Image):
             nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
@@ -107,6 +114,9 @@ def __call__(self, pic):
         if torch.is_tensor(pic):
             npimg = np.transpose(pic.numpy(), (1, 2, 0))
         assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
+        if len(npimg.shape) < 3:
+            # 2D input: add a trailing channel axis (HW -> HW1)
+            npimg = np.reshape(npimg, npimg.shape + (1,))
         if npimg.shape[2] == 1:
             npimg = npimg[:, :, 0]
 
""" if isinstance(self.size, int): - w, h = img.size + w, h=img.size if (w <= h and w == self.size) or (h <= w and h == self.size): return img if w < h: - ow = self.size - oh = int(self.size * h / w) + ow=self.size + oh=int(self.size * h / w) return img.resize((ow, oh), self.interpolation) else: - oh = self.size - ow = int(self.size * w / h) + oh=self.size + ow=int(self.size * w / h) return img.resize((ow, oh), self.interpolation) else: return img.resize(self.size[::-1], self.interpolation) @@ -213,9 +220,9 @@ class CenterCrop(object): def __init__(self, size): if isinstance(size, numbers.Number): - self.size = (int(size), int(size)) + self.size=(int(size), int(size)) else: - self.size = size + self.size=size def __call__(self, img): """ @@ -225,10 +232,10 @@ def __call__(self, img): Returns: PIL.Image: Cropped image. """ - w, h = img.size - th, tw = self.size - x1 = int(round((w - tw) / 2.)) - y1 = int(round((h - th) / 2.)) + w, h=img.size + th, tw=self.size + x1=int(round((w - tw) / 2.)) + y1=int(round((h - th) / 2.)) return img.crop((x1, y1, x1 + tw, y1 + th)) @@ -252,8 +259,8 @@ def __init__(self, padding, fill=0): raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) - self.padding = padding - self.fill = fill + self.padding=padding + self.fill=fill def __call__(self, img): """ @@ -275,7 +282,7 @@ class Lambda(object): def __init__(self, lambd): assert isinstance(lambd, types.LambdaType) - self.lambd = lambd + self.lambd=lambd def __call__(self, img): return self.lambd(img) @@ -296,10 +303,10 @@ class RandomCrop(object): def __init__(self, size, padding=0): if isinstance(size, numbers.Number): - self.size = (int(size), int(size)) + self.size=(int(size), int(size)) else: - self.size = size - self.padding = padding + self.size=size + self.padding=padding def __call__(self, img): """ @@ -310,15 +317,15 @@ def __call__(self, img): PIL.Image: Cropped image. """ if self.padding > 0: - img = ImageOps.expand(img, border=self.padding, fill=0) + img=ImageOps.expand(img, border=self.padding, fill=0) - w, h = img.size - th, tw = self.size + w, h=img.size + th, tw=self.size if w == tw and h == th: return img - x1 = random.randint(0, w - tw) - y1 = random.randint(0, h - th) + x1=random.randint(0, w - tw) + y1=random.randint(0, h - th) return img.crop((x1, y1, x1 + tw, y1 + th)) @@ -352,31 +359,31 @@ class RandomSizedCrop(object): """ def __init__(self, size, interpolation=Image.BILINEAR): - self.size = size - self.interpolation = interpolation + self.size=size + self.interpolation=interpolation def __call__(self, img): for attempt in range(10): - area = img.size[0] * img.size[1] - target_area = random.uniform(0.08, 1.0) * area - aspect_ratio = random.uniform(3. / 4, 4. / 3) + area=img.size[0] * img.size[1] + target_area=random.uniform(0.08, 1.0) * area + aspect_ratio=random.uniform(3. / 4, 4. 
/ 3) - w = int(round(math.sqrt(target_area * aspect_ratio))) - h = int(round(math.sqrt(target_area / aspect_ratio))) + w=int(round(math.sqrt(target_area * aspect_ratio))) + h=int(round(math.sqrt(target_area / aspect_ratio))) if random.random() < 0.5: - w, h = h, w + w, h=h, w if w <= img.size[0] and h <= img.size[1]: - x1 = random.randint(0, img.size[0] - w) - y1 = random.randint(0, img.size[1] - h) + x1=random.randint(0, img.size[0] - w) + y1=random.randint(0, img.size[1] - h) - img = img.crop((x1, y1, x1 + w, y1 + h)) + img=img.crop((x1, y1, x1 + w, y1 + h)) assert(img.size == (w, h)) return img.resize((self.size, self.size), self.interpolation) # Fallback - scale = Scale(self.size, interpolation=self.interpolation) - crop = CenterCrop(self.size) + scale=Scale(self.size, interpolation=self.interpolation) + crop=CenterCrop(self.size) return crop(scale(img))
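
For reviewers, a minimal sketch of the behavior this patch is meant to guarantee, assuming the patched torchvision is importable; the 4x4 sizes and the final print are arbitrary illustration, not part of the change:

    import numpy as np
    import torchvision.transforms as transforms

    to_tensor = transforms.ToTensor()

    # 3D HWC ndarray -> CxHxW float tensor scaled into [0, 1] (existing behavior)
    rgb = np.random.randint(low=0, high=255, size=(4, 4, 3))
    assert to_tensor(rgb).size() == (3, 4, 4)

    # 2D HW ndarray -> 1xHxW float tensor (new in this patch)
    gray = np.random.randint(low=0, high=255, size=(4, 4))
    assert to_tensor(gray).size() == (1, 4, 4)

    # ToPILImage likewise accepts a 2D ndarray after this patch
    assert transforms.ToPILImage()(np.uint8(gray)).mode == 'L'

    # higher-dimensional input now fails fast instead of dying inside transpose()
    try:
        to_tensor(np.zeros((2, 4, 4, 3)))
    except ValueError as e:
        print(e)  # only 2D and 3D images accepted, got 4D image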