Skip to content

Commit e6d3f8c

Browse files
authored
Add pil_to_tensor to functionals (#2092)
* Adds as_tensor to functional.py Similar functionality to to_tensor without the default conversion to float and division by 255. Also adds support for Image mode 'L'. * Adds tests to AsTensor() Adds tests to AsTensor and removes the conversion to float and division by 255. * Adds AsTensor to transforms.py Calls the as_tensor function in functionals and adds the function AsTensor as callable from transforms. * Removes the pic.mode == 'L' This was handled by the else condition previously so I'll remove it. * Fix Lint issue Adds two line breaks between functions to fix lint issue * Replace from_numpy with as_tensor Removes the extra if conditionals and replaces from_numpy with as_tensor. * Renames as_tensor to pil_to_tensor Renames the function as_tensor to pil_to_tensor and narrows the scope of the function. At the same time also creates a flag that defaults to True for swapping to the channels first format. * Renames AsTensor to PILToImage Renames the function AsTensor to PILToImage and modifies the description. Adds the swap_to_channelsfirst boolean variable to indicate if the user wishes to change the shape of the input. * Add the __init__ function to PILToTensor Add the __init__ function to PILToTensor since it contains the swap_to_channelsfirst parameter now. * fix lint issue remove trailing white space * Fix the tests Reflects the name change to PILToTensor and the parameter to the function as well as the new narrowed scope that the function only accepts PIL images. * fix tests Instead of undoing the transpose just create a new tensor and test that one. * Add the view back Add img.view(pic.size[1], pic.size[0], len(pic.getbands())) back to outside the if condition. * fix test fix conversion from torch tensor to PIL back to torch tensor. * fix lint issues * fix lint remove trailing white space * Fixed the channel swapping tensor test Torch transpose operates differently than numpy transpose. Changed operation to permute. 
* Add mode='F' Add mode information when converting to PIL Image from Float Tensor. * Added inline comments to follow shape changes * ToPILImage converts FloatTensors to uint8 * Remove testing not swapping * Removes the swap_channelsfirst parameter Makes the channel swapping the default behavior. * Remove the swap_channelsfirst argument Remove the swap_channelsfirst argument and makes the swapping the default functionality.
1 parent e2e511b commit e6d3f8c

File tree

3 files changed

+91
-1
lines changed

3 files changed

+91
-1
lines changed

test/test_transforms.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,49 @@ def test_accimage_to_tensor(self):
511511
self.assertEqual(expected_output.size(), output.size())
512512
self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))
513513

514+
def test_pil_to_tensor(self):
    """Check that ``PILToTensor`` preserves pixel values and emits CHW data.

    Covers rejection of non-PIL inputs, round trips through ``ToPILImage``
    starting from byte tensors, uint8 ndarrays and float tensors with 1, 3
    and 4 channels, and a separate mode ``'1'`` image.
    """
    test_channels = [1, 3, 4]
    height, width = 4, 4
    trans = transforms.PILToTensor()

    # Each invalid input needs its own context manager: inside a single
    # ``assertRaises`` block the first raise ends the block, so a second
    # call placed after it would never be executed.
    with self.assertRaises(TypeError):
        trans(np.random.rand(1, height, width).tolist())
    with self.assertRaises(TypeError):
        trans(np.random.rand(1, height, width))

    for channels in test_channels:
        # byte tensor (CHW) -> PIL -> tensor must be the identity
        input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

        # uint8 ndarray (HWC) -> PIL -> tensor must equal the CHW transpose
        input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        expected_output = input_data.transpose((2, 0, 1))
        self.assertTrue(np.allclose(output.numpy(), expected_output))

        input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32))
        img = transforms.ToPILImage()(input_data)  # CHW -> HWC and (* 255).byte()
        output = trans(img)  # HWC -> CHW
        expected_output = (input_data * 255).byte()
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    # separate test for mode '1' PIL images
    input_data = torch.ByteTensor(1, height, width).bernoulli_()
    img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
    output = trans(img)
    self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))
546+
547+
@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_pil_to_tensor(self):
    """Compare ``PILToTensor`` output for an accimage image and a PIL image.

    Both images are loaded from ``GRACE_HOPPER``; the resulting tensors
    must agree in size and, within ``np.allclose`` tolerance, in values.
    """
    to_tensor = transforms.PILToTensor()

    # The PIL result serves as the reference for the accimage result.
    from_pil = to_tensor(Image.open(GRACE_HOPPER).convert('RGB'))
    from_accimage = to_tensor(accimage.Image(GRACE_HOPPER))

    self.assertEqual(from_pil.size(), from_accimage.size())
    values_match = np.allclose(from_accimage.numpy(), from_pil.numpy())
    self.assertTrue(values_match)
556+
514557
@unittest.skipIf(accimage is None, 'accimage not available')
515558
def test_accimage_resize(self):
516559
trans = transforms.Compose([

torchvision/transforms/functional.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,33 @@ def to_tensor(pic):
8282
return img
8383

8484

85+
def pil_to_tensor(pic):
    """Convert a ``PIL Image`` to a tensor of the same type.

    See ``PILToTensor`` for more details.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image, laid out as (C x H x W).

    Raises:
        TypeError: If ``pic`` is not accepted by ``_is_pil_image``.
    """
    if not _is_pil_image(pic):
        raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))

    if accimage is not None and isinstance(pic, accimage.Image):
        # Copy into a uint8 buffer so the result is not converted to float,
        # in line with the "same type" contract and with the PIL branch.
        # NOTE(review): relies on accimage's ``copyto`` accepting uint8
        # arrays -- confirm against the accimage documentation.
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(nppic)
        return torch.as_tensor(nppic)

    # handle PIL Image
    # ``np.array`` makes a writable copy; ``torch.as_tensor`` shares memory
    # with its input, and an array viewed straight from the image may be
    # read-only.
    img = torch.as_tensor(np.array(pic))
    # make the channel axis explicit, including for single-band images
    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1))
    return img
110+
111+
85112
def to_pil_image(pic, mode=None):
86113
"""Convert a tensor or an ndarray to PIL Image.
87114

torchvision/transforms/transforms.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from . import functional as F
1616

1717

18-
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
18+
__all__ = ["Compose", "ToTensor", "PILToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
1919
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
2020
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
2121
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
@@ -95,6 +95,26 @@ def __repr__(self):
9595
return self.__class__.__name__ + '()'
9696

9797

98+
class PILToTensor(object):
    """Turn a ``PIL Image`` into a tensor of the same type.

    The PIL Image, read as (H x W x C), becomes a torch.Tensor shaped
    (C x H x W).
    """

    def __repr__(self):
        class_name = type(self).__name__
        return class_name + '()'

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # All of the conversion work lives in the functional API.
        converted = F.pil_to_tensor(pic)
        return converted
116+
117+
98118
class ToPILImage(object):
99119
"""Convert a tensor or an ndarray to PIL Image.
100120

0 commit comments

Comments
 (0)