diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 964504eb9dc..5871623d5aa 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -644,8 +644,7 @@ def get_params(img, scale, ratio): w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) - - if w <= img.size[0] and h <= img.size[1]: + if w <= img.size[0] and h <= img.size[1] and aspect_ratio <= max(ratio) and aspect_ratio >= min(ratio): i = random.randint(0, img.size[1] - h) j = random.randint(0, img.size[0] - w) return i, j, h, w