diff --git a/torchvision/transforms.py b/torchvision/transforms.py index df1bd6b55d9..b5f4306f91c 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -140,6 +140,25 @@ def __call__(self, img): return img.resize((ow, oh), self.interpolation) +class Crop(object): + """Crops the given PIL.Image according to the given top left (x1, y1) and + bottom right (x2, y2) coordinates. The coordinates are 0-indexed. + """ + + def __init__(self, x1, y1, x2, y2): + self.x1 = x1 + self.y1 = y1 + self.x2 = x2 + self.y2 = y2 + + def __call__(self, img): + w, h = img.size + assert(self.x1 < w and self.x1 >= 0) + assert(self.y1 < h and self.y1 >= 0) + assert(self.x2 < w and self.x2 >= 0) + assert(self.y2 < h and self.y2 >= 0) + return img.crop((self.x1, self.y1, self.x2, self.y2)) + class CenterCrop(object): """Crops the given PIL.Image at the center to have a region of the given size. size can be a tuple (target_height, target_width)