@@ -99,6 +99,120 @@ def test_pad(self):
99
99
"pad" , "Pad" , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
100
100
)
101
101
102
+ def test_crop (self ):
103
+ fn_kwargs = {"top" : 2 , "left" : 3 , "height" : 4 , "width" : 5 }
104
+ # Test transforms.RandomCrop with size and padding as tuple
105
+ meth_kwargs = {"size" : (4 , 5 ), "padding" : (4 , 4 ), "pad_if_needed" : True , }
106
+ self ._test_geom_op (
107
+ 'crop' , 'RandomCrop' , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
108
+ )
109
+
110
+ tensor = torch .randint (0 , 255 , (3 , 10 , 10 ), dtype = torch .uint8 )
111
+ # Test torchscript of transforms.RandomCrop with size as int
112
+ f = T .RandomCrop (size = 5 )
113
+ scripted_fn = torch .jit .script (f )
114
+ scripted_fn (tensor )
115
+
116
+ # Test torchscript of transforms.RandomCrop with size as [int, ]
117
+ f = T .RandomCrop (size = [5 , ], padding = [2 , ])
118
+ scripted_fn = torch .jit .script (f )
119
+ scripted_fn (tensor )
120
+
121
+ # Test torchscript of transforms.RandomCrop with size as list
122
+ f = T .RandomCrop (size = [6 , 6 ])
123
+ scripted_fn = torch .jit .script (f )
124
+ scripted_fn (tensor )
125
+
126
+ def test_center_crop (self ):
127
+ fn_kwargs = {"output_size" : (4 , 5 )}
128
+ meth_kwargs = {"size" : (4 , 5 ), }
129
+ self ._test_geom_op (
130
+ "center_crop" , "CenterCrop" , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
131
+ )
132
+ fn_kwargs = {"output_size" : (5 ,)}
133
+ meth_kwargs = {"size" : (5 , )}
134
+ self ._test_geom_op (
135
+ "center_crop" , "CenterCrop" , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
136
+ )
137
+ tensor = torch .randint (0 , 255 , (3 , 10 , 10 ), dtype = torch .uint8 )
138
+ # Test torchscript of transforms.CenterCrop with size as int
139
+ f = T .CenterCrop (size = 5 )
140
+ scripted_fn = torch .jit .script (f )
141
+ scripted_fn (tensor )
142
+
143
+ # Test torchscript of transforms.CenterCrop with size as [int, ]
144
+ f = T .CenterCrop (size = [5 , ])
145
+ scripted_fn = torch .jit .script (f )
146
+ scripted_fn (tensor )
147
+
148
+ # Test torchscript of transforms.CenterCrop with size as tuple
149
+ f = T .CenterCrop (size = (6 , 6 ))
150
+ scripted_fn = torch .jit .script (f )
151
+ scripted_fn (tensor )
152
+
153
+ def _test_geom_op_list_output (self , func , method , out_length , fn_kwargs = None , meth_kwargs = None ):
154
+ if fn_kwargs is None :
155
+ fn_kwargs = {}
156
+ if meth_kwargs is None :
157
+ meth_kwargs = {}
158
+ tensor , pil_img = self ._create_data (height = 20 , width = 20 )
159
+ transformed_t_list = getattr (F , func )(tensor , ** fn_kwargs )
160
+ transformed_p_list = getattr (F , func )(pil_img , ** fn_kwargs )
161
+ self .assertEqual (len (transformed_t_list ), len (transformed_p_list ))
162
+ self .assertEqual (len (transformed_t_list ), out_length )
163
+ for transformed_tensor , transformed_pil_img in zip (transformed_t_list , transformed_p_list ):
164
+ self .compareTensorToPIL (transformed_tensor , transformed_pil_img )
165
+
166
+ scripted_fn = torch .jit .script (getattr (F , func ))
167
+ transformed_t_list_script = scripted_fn (tensor .detach ().clone (), ** fn_kwargs )
168
+ self .assertEqual (len (transformed_t_list ), len (transformed_t_list_script ))
169
+ self .assertEqual (len (transformed_t_list_script ), out_length )
170
+ for transformed_tensor , transformed_tensor_script in zip (transformed_t_list , transformed_t_list_script ):
171
+ self .assertTrue (transformed_tensor .equal (transformed_tensor_script ),
172
+ msg = "{} vs {}" .format (transformed_tensor , transformed_tensor_script ))
173
+
174
+ # test for class interface
175
+ f = getattr (T , method )(** meth_kwargs )
176
+ scripted_fn = torch .jit .script (f )
177
+ output = scripted_fn (tensor )
178
+ self .assertEqual (len (output ), len (transformed_t_list_script ))
179
+
180
+ def test_five_crop (self ):
181
+ fn_kwargs = meth_kwargs = {"size" : (5 ,)}
182
+ self ._test_geom_op_list_output (
183
+ "five_crop" , "FiveCrop" , out_length = 5 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
184
+ )
185
+ fn_kwargs = meth_kwargs = {"size" : [5 , ]}
186
+ self ._test_geom_op_list_output (
187
+ "five_crop" , "FiveCrop" , out_length = 5 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
188
+ )
189
+ fn_kwargs = meth_kwargs = {"size" : (4 , 5 )}
190
+ self ._test_geom_op_list_output (
191
+ "five_crop" , "FiveCrop" , out_length = 5 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
192
+ )
193
+ fn_kwargs = meth_kwargs = {"size" : [4 , 5 ]}
194
+ self ._test_geom_op_list_output (
195
+ "five_crop" , "FiveCrop" , out_length = 5 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
196
+ )
197
+
198
+ def test_ten_crop (self ):
199
+ fn_kwargs = meth_kwargs = {"size" : (5 ,)}
200
+ self ._test_geom_op_list_output (
201
+ "ten_crop" , "TenCrop" , out_length = 10 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
202
+ )
203
+ fn_kwargs = meth_kwargs = {"size" : [5 , ]}
204
+ self ._test_geom_op_list_output (
205
+ "ten_crop" , "TenCrop" , out_length = 10 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
206
+ )
207
+ fn_kwargs = meth_kwargs = {"size" : (4 , 5 )}
208
+ self ._test_geom_op_list_output (
209
+ "ten_crop" , "TenCrop" , out_length = 10 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
210
+ )
211
+ fn_kwargs = meth_kwargs = {"size" : [4 , 5 ]}
212
+ self ._test_geom_op_list_output (
213
+ "ten_crop" , "TenCrop" , out_length = 10 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
214
+ )
215
+
102
216
103
217
if __name__ == '__main__' :
104
218
unittest .main ()
0 commit comments