Skip to content

Commit ff3f738

Browse files
yannadanialykhantejani
authored andcommitted
Normalize single images to make_grid (#360)
1 parent aafaa2a commit ff3f738

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torchvision/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def make_grid(tensor, nrow=8, padding=2,
3939
if tensor.dim() == 3: # single image
4040
if tensor.size(0) == 1: # if single-channel, convert to 3-channel
4141
tensor = torch.cat((tensor, tensor, tensor), 0)
42-
return tensor
42+
tensor = tensor.view(1, tensor.size(0), tensor.size(1), tensor.size(2))
43+
4344
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
4445
tensor = torch.cat((tensor, tensor, tensor), 1)
4546

@@ -65,6 +66,9 @@ def norm_range(t, range):
6566
else:
6667
norm_range(tensor, range)
6768

69+
if tensor.size(0) == 1:
70+
return tensor.squeeze()
71+
6872
# make the mini-batch of images into a grid
6973
nmaps = tensor.size(0)
7074
xmaps = min(nrow, nmaps)

0 commit comments

Comments
 (0)