@@ -1102,68 +1102,79 @@ def __repr__(self):
1102
1102
return format_string
1103
1103
1104
1104
1105
- class RandomRotation (object ):
1105
+ class RandomRotation (torch . nn . Module ):
1106
1106
"""Rotate the image by angle.
1107
+ The image can be a PIL Image or a Tensor, in which case it is expected
1108
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
1107
1109
1108
1110
Args:
1109
1111
degrees (sequence or float or int): Range of degrees to select from.
1110
1112
If degrees is a number instead of sequence like (min, max), the range of degrees
1111
1113
will be (-degrees, +degrees).
1112
- resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
1113
- An optional resampling filter. See `filters`_ for more information.
1114
+ resample (int, optional): An optional resampling filter. See `filters`_ for more information.
1114
1115
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
1116
+ If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
1115
1117
expand (bool, optional): Optional expansion flag.
1116
1118
If true, expands the output to make it large enough to hold the entire rotated image.
1117
1119
If false or omitted, make the output image the same size as the input image.
1118
1120
Note that the expand flag assumes rotation around the center and no translation.
1119
- center (2-tuple, optional): Optional center of rotation.
1120
- Origin is the upper left corner.
1121
+ center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
1121
1122
Default is the center of the image.
1122
1123
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
1123
1124
image. If int or float, the value is used for all bands respectively.
1124
- Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
1125
+ Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0.
1126
+ This option is not supported for Tensor input. Fill value for the area outside the transform in the output
1127
+ image is always 0.
1125
1128
1126
1129
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
1127
1130
1128
1131
"""
1129
1132
1130
1133
def __init__ (self , degrees , resample = False , expand = False , center = None , fill = None ):
1134
+ super ().__init__ ()
1131
1135
if isinstance (degrees , numbers .Number ):
1132
1136
if degrees < 0 :
1133
1137
raise ValueError ("If degrees is a single number, it must be positive." )
1134
- self . degrees = ( - degrees , degrees )
1138
+ degrees = [ - degrees , degrees ]
1135
1139
else :
1140
+ if not isinstance (degrees , Sequence ):
1141
+ raise TypeError ("degrees should be a sequence of length 2." )
1136
1142
if len (degrees ) != 2 :
1137
1143
raise ValueError ("If degrees is a sequence, it must be of len 2." )
1138
- self .degrees = degrees
1144
+
1145
+ self .degrees = [float (d ) for d in degrees ]
1146
+
1147
+ if center is not None :
1148
+ if not isinstance (center , Sequence ):
1149
+ raise TypeError ("center should be a sequence of length 2." )
1150
+ if len (center ) != 2 :
1151
+ raise ValueError ("center should be a sequence of length 2." )
1152
+
1153
+ self .center = center
1139
1154
1140
1155
self .resample = resample
1141
1156
self .expand = expand
1142
- self .center = center
1143
1157
self .fill = fill
1144
1158
1145
1159
@staticmethod
1146
- def get_params (degrees ) :
1160
+ def get_params (degrees : List [ float ]) -> float :
1147
1161
"""Get parameters for ``rotate`` for a random rotation.
1148
1162
1149
1163
Returns:
1150
- sequence: params to be passed to ``rotate`` for random rotation.
1164
+ float: angle parameter to be passed to ``rotate`` for random rotation.
1151
1165
"""
1152
- angle = random .uniform (degrees [0 ], degrees [1 ])
1153
-
1166
+ angle = float (torch .empty (1 ).uniform_ (float (degrees [0 ]), float (degrees [1 ])).item ())
1154
1167
return angle
1155
1168
1156
- def __call__ (self , img ):
1169
+ def forward (self , img ):
1157
1170
"""
1158
1171
Args:
1159
- img (PIL Image): Image to be rotated.
1172
+ img (PIL Image or Tensor ): Image to be rotated.
1160
1173
1161
1174
Returns:
1162
- PIL Image: Rotated image.
1175
+ PIL Image or Tensor : Rotated image.
1163
1176
"""
1164
-
1165
1177
angle = self .get_params (self .degrees )
1166
-
1167
1178
return F .rotate (img , angle , self .resample , self .expand , self .center , self .fill )
1168
1179
1169
1180
def __repr__ (self ):
0 commit comments