Commit 293bfae

Merge pull request #6 from pytorch/utils
adding utils to save image to disk, and to create grid of images
2 parents 4b4795f + 7acde28 commit 293bfae

2 files changed, 56 additions and 7 deletions

README.md

Lines changed: 18 additions & 7 deletions
@@ -4,6 +4,7 @@ This repository consists of:

 - [vision.datasets](#datasets) : Data loaders for popular vision datasets
 - [vision.transforms](#transforms) : Common image transformations such as random crop, rotations etc.
+- [vision.utils](#utils) : Useful stuff such as saving tensor (3 x H x W) as image to disk, given a mini-batch creating a grid of images, etc.
 - `[WIP] vision.models` : Model definitions and Pre-trained models for popular models such as AlexNet, VGG, ResNet etc.

 # Installation
@@ -60,8 +61,8 @@ Example:
 ```python
 import torchvision.datasets as dset
 import torchvision.transforms as transforms
-cap = dset.CocoCaptions(root = 'dir where images are',
-                        annFile = 'json annotation file',
+cap = dset.CocoCaptions(root = 'dir where images are',
+                        annFile = 'json annotation file',
                         transform=transforms.ToTensor())

 print('Number of samples: ', len(cap))
@@ -76,10 +77,10 @@ Output:
 ```
 Number of samples: 82783
 Image Size: (3L, 427L, 640L)
-[u'A plane emitting smoke stream flying over a mountain.',
-u'A plane darts across a bright blue sky behind a mountain covered in snow',
-u'A plane leaves a contrail above the snowy mountain top.',
-u'A mountain that has a plane flying overheard in the distance.',
+[u'A plane emitting smoke stream flying over a mountain.',
+u'A plane darts across a bright blue sky behind a mountain covered in snow',
+u'A plane leaves a contrail above the snowy mountain top.',
+u'A mountain that has a plane flying overheard in the distance.',
 u'A mountain view with a plume of smoke in the background']
 ```

@@ -174,7 +175,7 @@ rescaled to (size * height / width, size)
 Crops the given PIL.Image at the center to have a region of
 the given size. size can be a tuple (target_height, target_width)
 or an integer, in which case the target will be of a square shape (size, size)
-
+
 ### `RandomCrop(size)`
 Crops the given PIL.Image at a random location to have a region of
 the given size. size can be a tuple (target_height, target_width)
@@ -200,3 +201,13 @@ Given mean: (R, G, B) and std: (R, G, B), will normalize each channel of the tor
 - `ToTensor()` - Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
 - `ToPILImage()` - Converts a torch.*Tensor of range [0, 1] and shape C x H x W or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C to a PIL.Image of range [0, 255]

+
+# Utils
+
+### make_grid(tensor, nrow=8, padding=2)
+Given a 4D mini-batch Tensor of shape (B x C x H x W), makes a grid of images
+
+### save_image(tensor, filename, nrow=8, padding=2)
+Saves a given Tensor into an image file.
+
+If given a mini-batch tensor, will save the tensor as a grid of images.
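
Review note: the README text does not say how large the grid is or where each image lands. The following is a minimal pure-Python sketch of the layout arithmetic used by `make_grid` in `torchvision/utils.py` in this commit. The helper name `grid_layout` and its return format are hypothetical and exist only for illustration. The sketch assumes Python 3 true division and that ceiling division was intended for the row count; under Python 2, `nmaps / xmaps` on integers floors, so the `math.ceil` in the committed code would have no effect there.

```python
import math


def grid_layout(nmaps, height, width, nrow=8, padding=2):
    """Return (grid_shape, offsets) for a batch of nmaps images.

    Hypothetical helper that mirrors the index arithmetic of make_grid
    in this commit; it does not touch any tensors.
    """
    xmaps = min(nrow, nmaps)                # images per grid row
    ymaps = int(math.ceil(nmaps / xmaps))   # grid rows (assumes true division)
    cell_h, cell_w = height + padding, width + padding

    offsets = []
    for k in range(nmaps):
        y, x = divmod(k, xmaps)             # row-major order, as in the commit's loops
        # Same start index as the committed code: cell origin + 1 + padding / 2.
        offsets.append((y * cell_h + 1 + padding // 2,
                        x * cell_w + 1 + padding // 2))

    # The committed code always allocates 3 channels.
    return (3, cell_h * ymaps, cell_w * xmaps), offsets


shape, offsets = grid_layout(nmaps=10, height=32, width=32)
print(shape)       # (3, 68, 272)
print(offsets[8])  # (36, 2): first image of the second row
```

With the defaults, ten 32 x 32 images fill one row of eight and a second row of two, and each image starts two pixels into its 34 x 34 cell.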

torchvision/utils.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+
+def make_grid(tensor, nrow=8, padding=2):
+    """
+    Given a 4D mini-batch Tensor of shape (B x C x H x W),
+    makes a grid of images
+    """
+    import math
+    if tensor.dim() == 3: # single image
+        return tensor
+    # make the mini-batch of images into a grid
+    nmaps = tensor.size(0)
+    xmaps = min(nrow, nmaps)
+    ymaps = int(math.ceil(nmaps / xmaps))
+    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
+    grid = tensor.new(3, height * ymaps, width * xmaps).fill_(tensor.max())
+    k = 0
+    for y in range(ymaps):
+        for x in range(xmaps):
+            if k >= nmaps:
+                break
+            grid.narrow(1, y*height+1+padding/2,height-padding)\
+                .narrow(2, x*width+1+padding/2, width-padding)\
+                .copy_(tensor[k])
+            k = k + 1
+    return grid
+
+
+def save_image(tensor, filename, nrow=8, padding=2):
+    """
+    Saves a given Tensor into an image file.
+    If given a mini-batch tensor, will save the tensor as a grid of images.
+    """
+    from PIL import Image
+    tensor = tensor.cpu()
+    grid = make_grid(tensor, nrow=nrow, padding=padding)
+    ndarr = grid.mul(0.5).add(0.5).mul(255).byte().transpose(0,2).transpose(0,1).numpy()
+    im = Image.fromarray(ndarr)
+    im.save(filename)
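
Review note: the `mul(0.5).add(0.5).mul(255).byte()` chain in `save_image` implies an input range that the docstring does not state. The sketch below reproduces that per-pixel arithmetic in plain Python so the mapping is easy to check. The helper name `to_byte` is hypothetical, and the sketch assumes the byte conversion truncates toward zero and that values stay within the range shown.

```python
def to_byte(value):
    """Map one float pixel value the way the save_image chain does:
    value * 0.5 + 0.5, then * 255, then truncate to an integer.

    Hypothetical helper for illustration; it ignores out-of-range inputs.
    """
    return int((value * 0.5 + 0.5) * 255)


for v in (-1.0, 0.0, 1.0):
    print(v, "->", to_byte(v))
# -1.0 -> 0
# 0.0 -> 127
# 1.0 -> 255
```

The arithmetic maps [-1, 1] onto [0, 255]. A tensor already in [0, 1], such as the output of `ToTensor()`, would land in the upper half of the range (127 to 255) under this formula.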
