@@ -687,40 +687,56 @@ def __repr__(self):
687687 return self .__class__ .__name__ + '(p={})' .format (self .p )
688688
689689
class RandomResizedCrop(torch.nn.Module):
    """Crop the given image to random size and aspect ratio.

    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a
    random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
    is made. This crop is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
        scale (tuple of float): range of size of the origin size cropped
        ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
        interpolation (int): Desired interpolation. Default: ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        super().__init__()
        # Normalize ``size`` to an (h, w) pair: a bare number means a square
        # output, a length-1 sequence is duplicated, anything else must have
        # exactly two entries.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            self.size = (size[0], size[0])
        else:
            if len(size) != 2:
                raise ValueError("Please provide only two dimensions (h, w) for size.")
            self.size = size

        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        # Misordered ranges are tolerated (warning only) to stay
        # backward-compatible with callers that passed (max, min).
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(
            img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float]
    ) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (tuple): range of scale of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        # NOTE(review): this line sits in the gap between the two diff hunks
        # (old lines 728-746); torchvision's version of this function obtains
        # the spatial size via the functional helper — confirm against the
        # full file.
        width, height = F._get_image_size(img)
        area = height * width

        # Rejection-sample up to 10 crops; torch RNG (not ``random``) is used
        # so results are controlled by torch.manual_seed and scriptable.
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(*scale).item()
            # Sample aspect ratio log-uniformly so e.g. 3/4 and 4/3 are
            # equally likely.
            log_ratio = torch.log(torch.tensor(ratio))
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
            ).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                # torch.randint's upper bound is exclusive, hence the +1 to
                # include the boundary position.
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w

        # Fallback to central crop, clamped to the closest allowed ratio.
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
0 commit comments