Skip to content

add support for single channel images #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torchvision.transforms as transforms
import unittest
from parameterized import parameterized

This comment was marked as off-topic.

This comment was marked as off-topic.

import random
import numpy as np
from PIL import Image
Expand Down Expand Up @@ -171,17 +172,35 @@ def test_lambda(self):
y = trans(x)
assert (y.equal(x))

def test_to_tensor(self):
channels = 3
height, width = 4, 4
trans = transforms.ToTensor()
input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
@parameterized.expand([
('3channel', 3, 4, 4),
('1channel', 1, 4, 4),
])
def test_pil_to_tensor(self, _, channels, height, width):

This comment was marked as off-topic.

input_data = torch.ByteTensor(channels, height, width)
input_data = input_data.random_(0, 255).float().div_(255)
img = transforms.ToPILImage()(input_data)
output = trans(img)
output = transforms.ToTensor()(img)
assert np.allclose(input_data.numpy(), output.numpy())

ndarray = np.random.randint(low=0, high=255, size=(height, width, channels))
output = trans(ndarray)
@parameterized.expand([
('smoke', 4, 4),
])
def test_ndarray_to_tensor_2dim(self, _, height, width):
ndarray_size = (height, width)
ndarray = np.random.randint(low=0, high=255, size=ndarray_size)
output = transforms.ToTensor()(ndarray)
expected_output = ndarray[..., np.newaxis].transpose((2, 0, 1)) / 255.0
assert np.allclose(output.numpy(), expected_output)

@parameterized.expand([
('1channel', 1, 4, 4),
('3channel', 3, 4, 4),
])
def test_ndarray_to_tensor_3dim(self, _, channels, height, width):
ndarray_size = (height, width, channels)
ndarray = np.random.randint(low=0, high=255, size=ndarray_size)
output = transforms.ToTensor()(ndarray)
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
assert np.allclose(output.numpy(), expected_output)

Expand Down
129 changes: 68 additions & 61 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,33 +52,38 @@ def __call__(self, pic):
"""
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1)))
if pic.ndim == 2:
pic = pic[np.newaxis]
elif pic.ndim == 3:
pic.transpose((2, 0, 1))
else:
raise ValueError('only 2D and 3D images accepted, got {}D image'.format(pic.ndim)
# backward compatibility
return img.float().div(255)
return torch.from_numpy(pic).float().div(255.)

if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
nppic=np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)

This comment was marked as off-topic.

pic.copyto(nppic)
return torch.from_numpy(nppic)

# handle PIL Image
if pic.mode == 'I':
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
img=torch.from_numpy(np.array(pic, np.int32, copy=False))
elif pic.mode == 'I;16':
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
img=torch.from_numpy(np.array(pic, np.int16, copy=False))
else:
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
img=torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
if pic.mode == 'YCbCr':
nchannel = 3
nchannel=3
elif pic.mode == 'I;16':
nchannel = 1
nchannel=1
else:
nchannel = len(pic.mode)
img = img.view(pic.size[1], pic.size[0], nchannel)
nchannel=len(pic.mode)
img=img.view(pic.size[1], pic.size[0], nchannel)
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
img=img.transpose(0, 1).transpose(0, 2).contiguous()
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
else:
Expand All @@ -101,30 +106,32 @@ def __call__(self, pic):
PIL.Image: Image converted to PIL.Image.

"""
npimg = pic
mode = None
npimg=pic
mode=None
if isinstance(pic, torch.FloatTensor):
pic = pic.mul(255).byte()
pic=pic.mul(255).byte()
if torch.is_tensor(pic):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
npimg=np.transpose(pic.numpy(), (1, 2, 0))
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
if len(npimg.shape) < 3:

This comment was marked as off-topic.

npimg=np.reshape(npimg, npimg.shape + (1,))
if npimg.shape[2] == 1:
npimg = npimg[:, :, 0]
npimg=npimg[:, :, 0]

if npimg.dtype == np.uint8:
mode = 'L'
mode='L'
if npimg.dtype == np.int16:
mode = 'I;16'
mode='I;16'
if npimg.dtype == np.int32:
mode = 'I'
mode='I'
elif npimg.dtype == np.float32:
mode = 'F'
mode='F'
elif npimg.shape[2] == 4:
if npimg.dtype == np.uint8:
mode = 'RGBA'
mode='RGBA'
else:
if npimg.dtype == np.uint8:
mode = 'RGB'
mode='RGB'
assert mode is not None, '{} is not supported'.format(npimg.dtype)
return Image.fromarray(npimg, mode=mode)

Expand All @@ -143,8 +150,8 @@ class Normalize(object):
"""

def __init__(self, mean, std):
self.mean = mean
self.std = std
self.mean=mean
self.std=std

def __call__(self, tensor):
"""
Expand Down Expand Up @@ -175,8 +182,8 @@ class Scale(object):

def __init__(self, size, interpolation=Image.BILINEAR):
assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
self.size = size
self.interpolation = interpolation
self.size=size
self.interpolation=interpolation

def __call__(self, img):
"""
Expand All @@ -187,16 +194,16 @@ def __call__(self, img):
PIL.Image: Rescaled image.
"""
if isinstance(self.size, int):
w, h = img.size
w, h=img.size
if (w <= h and w == self.size) or (h <= w and h == self.size):
return img
if w < h:
ow = self.size
oh = int(self.size * h / w)
ow=self.size
oh=int(self.size * h / w)
return img.resize((ow, oh), self.interpolation)
else:
oh = self.size
ow = int(self.size * w / h)
oh=self.size
ow=int(self.size * w / h)
return img.resize((ow, oh), self.interpolation)
else:
return img.resize(self.size[::-1], self.interpolation)
Expand All @@ -213,9 +220,9 @@ class CenterCrop(object):

def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
self.size=(int(size), int(size))
else:
self.size = size
self.size=size

def __call__(self, img):
"""
Expand All @@ -225,10 +232,10 @@ def __call__(self, img):
Returns:
PIL.Image: Cropped image.
"""
w, h = img.size
th, tw = self.size
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
w, h=img.size
th, tw=self.size
x1=int(round((w - tw) / 2.))
y1=int(round((h - th) / 2.))
return img.crop((x1, y1, x1 + tw, y1 + th))


Expand All @@ -252,8 +259,8 @@ def __init__(self, padding, fill=0):
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))

self.padding = padding
self.fill = fill
self.padding=padding
self.fill=fill

def __call__(self, img):
"""
Expand All @@ -275,7 +282,7 @@ class Lambda(object):

def __init__(self, lambd):
assert isinstance(lambd, types.LambdaType)
self.lambd = lambd
self.lambd=lambd

def __call__(self, img):
return self.lambd(img)
Expand All @@ -296,10 +303,10 @@ class RandomCrop(object):

def __init__(self, size, padding=0):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
self.size=(int(size), int(size))
else:
self.size = size
self.padding = padding
self.size=size
self.padding=padding

def __call__(self, img):
"""
Expand All @@ -310,15 +317,15 @@ def __call__(self, img):
PIL.Image: Cropped image.
"""
if self.padding > 0:
img = ImageOps.expand(img, border=self.padding, fill=0)
img=ImageOps.expand(img, border=self.padding, fill=0)

w, h = img.size
th, tw = self.size
w, h=img.size
th, tw=self.size
if w == tw and h == th:
return img

x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
x1=random.randint(0, w - tw)
y1=random.randint(0, h - th)
return img.crop((x1, y1, x1 + tw, y1 + th))


Expand Down Expand Up @@ -352,31 +359,31 @@ class RandomSizedCrop(object):
"""

def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
self.size=size
self.interpolation=interpolation

def __call__(self, img):
for attempt in range(10):
area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area
aspect_ratio = random.uniform(3. / 4, 4. / 3)
area=img.size[0] * img.size[1]
target_area=random.uniform(0.08, 1.0) * area
aspect_ratio=random.uniform(3. / 4, 4. / 3)

w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
w=int(round(math.sqrt(target_area * aspect_ratio)))
h=int(round(math.sqrt(target_area / aspect_ratio)))

if random.random() < 0.5:
w, h = h, w
w, h=h, w

if w <= img.size[0] and h <= img.size[1]:
x1 = random.randint(0, img.size[0] - w)
y1 = random.randint(0, img.size[1] - h)
x1=random.randint(0, img.size[0] - w)
y1=random.randint(0, img.size[1] - h)

img = img.crop((x1, y1, x1 + w, y1 + h))
img=img.crop((x1, y1, x1 + w, y1 + h))
assert(img.size == (w, h))

return img.resize((self.size, self.size), self.interpolation)

# Fallback
scale = Scale(self.size, interpolation=self.interpolation)
crop = CenterCrop(self.size)
scale=Scale(self.size, interpolation=self.interpolation)
crop=CenterCrop(self.size)
return crop(scale(img))