@@ -687,40 +687,56 @@ def __repr__(self):
687
687
return self .__class__ .__name__ + '(p={})' .format (self .p )
688
688
689
689
690
class RandomResizedCrop(torch.nn.Module):
    """Crop the given image to random size and aspect ratio.

    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
        scale (tuple of float): range of size of the origin size cropped
        ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
        interpolation (int): Desired interpolation. Default: ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        super().__init__()
        # Normalize `size` to an (h, w) pair: a bare number means a square
        # output; a length-1 sequence is interpreted as (size[0], size[0]).
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            self.size = (size[0], size[0])
        else:
            if len(size) != 2:
                raise ValueError("Please provide only two dimensions (h, w) for size.")
            self.size = size

        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")

        # Misordered ranges are tolerated (warn only) to stay backward compatible.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(
            img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float]
    ) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (tuple): range of scale of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.

        NOTE(review): the ``width, height`` line and the whole-image fallback
        assignments below fell between diff hunks in the reviewed source and
        were reconstructed from the canonical torchvision implementation —
        confirm against the full file.
        """
        width, height = F._get_image_size(img)
        area = height * width

        # Rejection-sample a crop up to 10 times; torch RNG (rather than the
        # `random` module) makes sampling reproducible under torch.manual_seed.
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(*scale).item()
            # Sample aspect ratio uniformly in log-space so e.g. 3/4 and 4/3
            # are equally likely.
            log_ratio = torch.log(torch.tensor(ratio))
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
            ).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                # randint's upper bound is exclusive, hence the +1 to allow
                # the crop to touch the bottom/right edge.
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
0 commit comments