Skip to content

fb.resnet like transforms #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 113 additions & 3 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numbers
import types


class Compose(object):
""" Composes several transforms together.
For example:
Expand All @@ -15,6 +16,7 @@ class Compose(object):
>>> transforms.ToTensor(),
>>> ])
"""

def __init__(self, transforms):
self.transforms = transforms

Expand All @@ -27,44 +29,51 @@ def __call__(self, img):
class ToTensor(object):
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """

def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic)
else:
# handle PIL Image
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
img = torch.ByteTensor(
torch.ByteStorage.from_buffer(pic.tobytes()))
img = img.view(pic.size[1], pic.size[0], len(pic.mode))
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
return img.float().div(255)


class ToPILImage(object):
""" Converts a torch.*Tensor of range [0, 1] and shape C x H x W
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
to a PIL.Image of range [0, 255]
"""

def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
img = Image.fromarray(pic)
else:
npimg = pic.mul(255).byte().numpy()
npimg = np.transpose(npimg, (1,2,0))
npimg = np.transpose(npimg, (1, 2, 0))
img = Image.fromarray(npimg)
return img


class Normalize(object):
""" Given mean: (R, G, B) and std: (R, G, B),
will normalize each channel of the torch.*Tensor, i.e.
channel = (channel - mean) / std
"""

def __init__(self, mean, std):
self.mean = mean
self.std = std

def __call__(self, tensor):
def __call__(self, img):
tensor = img.clone()
# TODO: make efficient
for t, m, s in zip(tensor, self.mean, self.std):
t.sub_(m).div_(s)
Expand All @@ -79,6 +88,7 @@ class Scale(object):
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""

def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
Expand All @@ -102,6 +112,7 @@ class CenterCrop(object):
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""

def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
Expand All @@ -118,6 +129,7 @@ def __call__(self, img):

class Pad(object):
"""Pads the given PIL.Image on all sides with the given "pad" value"""

def __init__(self, padding, fill=0):
assert isinstance(padding, numbers.Number)
assert isinstance(fill, numbers.Number)
Expand All @@ -127,8 +139,10 @@ def __init__(self, padding, fill=0):
def __call__(self, img):
return ImageOps.expand(img, border=self.padding, fill=self.fill)


class Lambda(object):
"""Applies a lambda as a transform"""

def __init__(self, lambd):
assert type(lambd) is types.LambdaType
self.lambd = lambd
Expand All @@ -142,6 +156,7 @@ class RandomCrop(object):
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""

def __init__(self, size, padding=0):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
Expand All @@ -166,6 +181,7 @@ def __call__(self, img):
class RandomHorizontalFlip(object):
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
"""

def __call__(self, img):
if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT)
Expand All @@ -179,6 +195,7 @@ class RandomSizedCrop(object):
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""

def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
Expand Down Expand Up @@ -208,3 +225,96 @@ def __call__(self, img):
scale = Scale(self.size, interpolation=self.interpolation)
crop = CenterCrop(self.size)
return crop(scale(img))


class Lighting(object):
"""Lighting noise(AlexNet - style PCA - based noise)"""

def __init__(self, alphastd, eigval, eigvec):
self.alphastd = alphastd
self.eigval = eigval
self.eigvec = eigvec

def __call__(self, img):
if self.alphastd == 0:
return img

alpha = img.new().resize_(3).normal_(0, self.alphastd)
rgb = self.eigvec.type_as(img).clone()\
.mul(alpha.view(1, 3).expand(3, 3))\
.mul(self.eigval.view(1, 3).expand(3, 3))\
.sum(1).squeeze()

return img.add(rgb.view(3, 1, 1).expand_as(img))


class Grayscale(object):

def __call__(self, img):
gs = img.clone()
gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2])
gs[1].copy_(gs[0])
gs[2].copy_(gs[0])
return gs


class Saturation(object):

def __init__(self, var):
self.var = var

def __call__(self, img):
gs = Grayscale()(img)
alpha = random.uniform(0, self.var)
return img.lerp(gs, alpha)


class Brightness(object):

def __init__(self, var):
self.var = var

def __call__(self, img):
gs = img.new().resize_as_(img).zero_()
alpha = random.uniform(0, self.var)
return img.lerp(gs, alpha)


class Contrast(object):

def __init__(self, var):
self.var = var

def __call__(self, img):
gs = Grayscale()(img)
gs.fill_(gs.mean())
alpha = random.uniform(0, self.var)
return img.lerp(gs, alpha)


class RandomOrder(object):
""" Composes several transforms together in random order.
"""

def __init__(self, transforms):
self.transforms = transforms

def __call__(self, img):
if self.transforms is None:
return img
order = torch.randperm(len(self.transforms))
for i in order:
img = self.transforms[i](img)
return img


class ColorJitter(RandomOrder):

def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
self.transforms = []
if brightness != 0:
self.transforms.append(Brightness(brightness))
if contrast != 0:
self.transforms.append(Contrast(contrast))
if saturation != 0:
self.transforms.append(Saturation(saturation))