Skip to content

Commit a51c49e

Browse files
ademyanchuk, Demyanchuk, and vfdev-5
authored
Add explicit check for number of channels (#3013)
* Add explicit check for number of channels

  Example of why the check is needed:

  `M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype = torch.float)`

  When this input is passed through to_pil_image without a mode argument, it is converted to uint8 here:

  ```
  if pic.is_floating_point() and mode != 'F':
      pic = pic.mul(255).byte()
  ```

  and the mode is changed to RGB here:

  ```
  if mode is None and npimg.dtype == np.uint8:
      mode = 'RGB'
  ```

  Image.fromarray doesn't raise when given mode RGB; it silently cuts the number of channels down from whatever the input has to 3.

* Check number of channels before processing
* Add test for invalid number of channels
* Add explicit check for number of channels

  Example of why the check is needed:

  `M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype = torch.float)`

  When this input is passed through to_pil_image without a mode argument, it is converted to uint8 here:

  ```
  if pic.is_floating_point() and mode != 'F':
      pic = pic.mul(255).byte()
  ```

  and the mode is changed to RGB here:

  ```
  if mode is None and npimg.dtype == np.uint8:
      mode = 'RGB'
  ```

  Image.fromarray doesn't raise when given mode RGB; it silently cuts the number of channels down from whatever the input has to 3.

* Check number of channels before processing
* Add test for invalid number of channels
* Put check after channel dim unsqueeze
* Add test that the error message matches
* Delete redundant code
* Bug fix in checking for bad types

Co-authored-by: Demyanchuk <[email protected]>
Co-authored-by: vfdev <[email protected]>
1 parent cd0268c commit a51c49e

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

test/test_transforms.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -987,19 +987,27 @@ def test_2d_ndarray_to_pil_image(self):
987987
self.assertTrue(np.allclose(img_data, img))
988988

989989
def test_tensor_bad_types_to_pil_image(self):
990-
with self.assertRaises(ValueError):
990+
with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'):
991991
transforms.ToPILImage()(torch.ones(1, 3, 4, 4))
992+
with self.assertRaisesRegex(ValueError, r'pic should not have > 4 channels. Got \d+ channels.'):
993+
transforms.ToPILImage()(torch.ones(6, 4, 4))
992994

993995
def test_ndarray_bad_types_to_pil_image(self):
994996
trans = transforms.ToPILImage()
995-
with self.assertRaises(TypeError):
997+
reg_msg = r'Input type \w+ is not supported'
998+
with self.assertRaisesRegex(TypeError, reg_msg):
996999
trans(np.ones([4, 4, 1], np.int64))
1000+
with self.assertRaisesRegex(TypeError, reg_msg):
9971001
trans(np.ones([4, 4, 1], np.uint16))
1002+
with self.assertRaisesRegex(TypeError, reg_msg):
9981003
trans(np.ones([4, 4, 1], np.uint32))
1004+
with self.assertRaisesRegex(TypeError, reg_msg):
9991005
trans(np.ones([4, 4, 1], np.float64))
10001006

1001-
with self.assertRaises(ValueError):
1007+
with self.assertRaisesRegex(ValueError, r'pic should be 2/3 dimensional. Got \d+ dimensions.'):
10021008
transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
1009+
with self.assertRaisesRegex(ValueError, r'pic should not have > 4 channels. Got \d+ channels.'):
1010+
transforms.ToPILImage()(np.ones([4, 4, 6]))
10031011

10041012
@unittest.skipIf(stats is None, 'scipy.stats not available')
10051013
def test_random_vertical_flip(self):

torchvision/transforms/functional.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ def to_pil_image(pic, mode=None):
183183
# if 2D image, add channel dimension (CHW)
184184
pic = pic.unsqueeze(0)
185185

186+
# check number of channels
187+
if pic.shape[-3] > 4:
188+
raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-3]))
189+
186190
elif isinstance(pic, np.ndarray):
187191
if pic.ndim not in {2, 3}:
188192
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
@@ -191,6 +195,10 @@ def to_pil_image(pic, mode=None):
191195
# if 2D image, add channel dimension (HWC)
192196
pic = np.expand_dims(pic, 2)
193197

198+
# check number of channels
199+
if pic.shape[-1] > 4:
200+
raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-1]))
201+
194202
npimg = pic
195203
if isinstance(pic, torch.Tensor):
196204
if pic.is_floating_point() and mode != 'F':

0 commit comments

Comments
 (0)