diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 48be812569b..f19f07288d0 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -34,7 +34,7 @@ def __call__(self, pic): else: # handle PIL Image img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) - img = img.view(pic.size[1], pic.size[0], 3) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) # put it from HWC to CHW format # yikes, this transpose takes 80% of the loading time/CPU img = img.transpose(0, 1).transpose(0, 2).contiguous()