From 659c854c6971ecc5b94dca3f4459ef2b7e42fb70 Mon Sep 17 00:00:00 2001 From: Elad Hoffer Date: Wed, 18 Jan 2017 23:28:12 +0200 Subject: [PATCH] fb.resnet like transforms --- torchvision/transforms.py | 116 +++++++++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 2be81389c2c..78c958afffb 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -7,6 +7,7 @@ import numbers import types + class Compose(object): """ Composes several transforms together. For example: @@ -15,6 +16,7 @@ class Compose(object): >>> transforms.ToTensor(), >>> ]) """ + def __init__(self, transforms): self.transforms = transforms @@ -27,44 +29,51 @@ def __call__(self, img): class ToTensor(object): """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + def __call__(self, pic): if isinstance(pic, np.ndarray): # handle numpy array img = torch.from_numpy(pic) else: # handle PIL Image - img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) img = img.view(pic.size[1], pic.size[0], len(pic.mode)) # put it from HWC to CHW format # yikes, this transpose takes 80% of the loading time/CPU img = img.transpose(0, 1).transpose(0, 2).contiguous() return img.float().div(255) + class ToPILImage(object): """ Converts a torch.*Tensor of range [0, 1] and shape C x H x W or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C to a PIL.Image of range [0, 255] """ + def __call__(self, pic): if isinstance(pic, np.ndarray): # handle numpy array img = Image.fromarray(pic) else: npimg = pic.mul(255).byte().numpy() - npimg = np.transpose(npimg, (1,2,0)) + npimg = np.transpose(npimg, (1, 2, 0)) img = Image.fromarray(npimg) return img + class Normalize(object): """ Given mean: (R, G, B) and std: (R, G, B), 
will normalize each channel of the torch.*Tensor, i.e. channel = (channel - mean) / std """ + def __init__(self, mean, std): self.mean = mean self.std = std - def __call__(self, tensor): + def __call__(self, img): + tensor = img.clone() # TODO: make efficient for t, m, s in zip(tensor, self.mean, self.std): t.sub_(m).div_(s) @@ -79,6 +88,7 @@ class Scale(object): size: size of the smaller edge interpolation: Default: PIL.Image.BILINEAR """ + def __init__(self, size, interpolation=Image.BILINEAR): self.size = size self.interpolation = interpolation @@ -102,6 +112,7 @@ class CenterCrop(object): the given size. size can be a tuple (target_height, target_width) or an integer, in which case the target will be of a square shape (size, size) """ + def __init__(self, size): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) @@ -118,6 +129,7 @@ def __call__(self, img): class Pad(object): """Pads the given PIL.Image on all sides with the given "pad" value""" + def __init__(self, padding, fill=0): assert isinstance(padding, numbers.Number) assert isinstance(fill, numbers.Number) @@ -127,8 +139,10 @@ def __init__(self, padding, fill=0): def __call__(self, img): return ImageOps.expand(img, border=self.padding, fill=self.fill) + class Lambda(object): """Applies a lambda as a transform""" + def __init__(self, lambd): assert type(lambd) is types.LambdaType self.lambd = lambd @@ -142,6 +156,7 @@ class RandomCrop(object): the given size. 
class Lighting(object):
    """Lighting noise (AlexNet-style, PCA-based) for a 3 x H x W tensor.

    A random weight ``alpha[j] ~ N(0, alphastd)`` is drawn for each of the
    three principal components, and the per-channel offset
    ``rgb[i] = sum_j eigvec[i, j] * alpha[j] * eigval[j]``
    is added to every pixel of channel ``i``.

    Args:
        alphastd: standard deviation of the component weights; ``0``
            disables the transform (the input is returned unchanged).
        eigval: tensor with 3 elements (one eigenvalue per component).
        eigvec: 3 x 3 tensor, multiplied column-wise by ``alpha * eigval``.
    """

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        """Return a new tensor equal to ``img`` plus the lighting offset."""
        if self.alphastd == 0:
            return img

        alpha = img.new_empty(3).normal_(0, self.alphastd)
        # Cast BOTH PCA tensors to the image type. Casting only eigvec lets a
        # float64 eigval promote the result (and the returned image) to
        # float64.
        eigvec = self.eigvec.type_as(img)
        eigval = self.eigval.type_as(img)
        weights = (alpha * eigval).view(1, 3).expand(3, 3)
        rgb = eigvec.mul(weights).sum(1)

        return img.add(rgb.view(3, 1, 1).expand_as(img))


class Grayscale(object):
    """Grayscale version of a 3 x H x W floating-point tensor.

    The luma ``0.299 * R + 0.587 * G + 0.114 * B`` is computed and written
    to all three channels of a copy; the input tensor is not modified.
    """

    def __call__(self, img):
        gs = img.clone()
        # Keyword ``alpha`` replaces the deprecated add_(Number, Tensor)
        # positional overload.
        gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)
        gs[1].copy_(gs[0])
        gs[2].copy_(gs[0])
        return gs


class Saturation(object):
    """Blend the image towards its grayscale version.

    The blend weight is drawn uniformly from ``[0, var]`` on every call:
    ``out = img + alpha * (grayscale(img) - img)``.
    """

    def __init__(self, var):
        self.var = var

    def __call__(self, img):
        gs = Grayscale()(img)
        alpha = random.uniform(0, self.var)
        return img.lerp(gs, alpha)


class Brightness(object):
    """Blend the image towards an all-zero image.

    The blend weight is drawn uniformly from ``[0, var]`` on every call, so
    the result is ``img * (1 - alpha)``.
    """

    def __init__(self, var):
        self.var = var

    def __call__(self, img):
        gs = torch.zeros_like(img)
        alpha = random.uniform(0, self.var)
        return img.lerp(gs, alpha)


class Contrast(object):
    """Blend the image towards a constant image holding its mean luma.

    The blend weight is drawn uniformly from ``[0, var]`` on every call:
    ``out = img + alpha * (mean(grayscale(img)) - img)``.
    """

    def __init__(self, var):
        self.var = var

    def __call__(self, img):
        gs = Grayscale()(img)
        gs.fill_(gs.mean())
        alpha = random.uniform(0, self.var)
        return img.lerp(gs, alpha)


class RandomOrder(object):
    """Composes several transforms together in random order.

    Args:
        transforms: sequence of callables, each taking and returning an
            image. ``None`` or an empty sequence makes this a no-op.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Apply every transform exactly once, in a freshly shuffled order."""
        if not self.transforms:
            return img
        # tolist() yields plain Python ints, so indexing the transform list
        # does not depend on tensor-element __index__ support.
        order = torch.randperm(len(self.transforms)).tolist()
        for i in order:
            img = self.transforms[i](img)
        return img


class ColorJitter(RandomOrder):
    """Apply Brightness, Contrast and Saturation jitter in random order.

    Args:
        brightness: ``var`` passed to ``Brightness``; ``0`` omits it.
        contrast: ``var`` passed to ``Contrast``; ``0`` omits it.
        saturation: ``var`` passed to ``Saturation``; ``0`` omits it.
    """

    def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
        transforms = []
        if brightness != 0:
            transforms.append(Brightness(brightness))
        if contrast != 0:
            transforms.append(Contrast(contrast))
        if saturation != 0:
            transforms.append(Saturation(saturation))
        super(ColorJitter, self).__init__(transforms)