@@ -59,12 +59,11 @@ def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=
59
59
_assert_approx_equal_tensor_to_pil (transformed_tensor , transformed_pil_img , ** match_kwargs )
60
60
61
61
62
- def _test_class_op (method , device , channels = 3 , meth_kwargs = None , test_exact_match = True , ** match_kwargs ):
63
- # TODO: change the name: it's not a method, it's a class.
62
+ def _test_class_op (transform_cls , device , channels = 3 , meth_kwargs = None , test_exact_match = True , ** match_kwargs ):
64
63
meth_kwargs = meth_kwargs or {}
65
64
66
65
# test for class interface
67
- f = method (** meth_kwargs )
66
+ f = transform_cls (** meth_kwargs )
68
67
scripted_fn = torch .jit .script (f )
69
68
70
69
tensor , pil_img = _create_data (26 , 34 , channels , device = device )
@@ -86,7 +85,7 @@ def _test_class_op(method, device, channels=3, meth_kwargs=None, test_exact_matc
86
85
_test_transform_vs_scripted_on_batch (f , scripted_fn , batch_tensors )
87
86
88
87
with get_tmp_dir () as tmp_dir :
89
- scripted_fn .save (os .path .join (tmp_dir , f"t_{ method .__name__ } .pt" ))
88
+ scripted_fn .save (os .path .join (tmp_dir , f"t_{ transform_cls .__name__ } .pt" ))
90
89
91
90
92
91
def _test_op (func , method , device , channels = 3 , fn_kwargs = None , meth_kwargs = None , test_exact_match = True , ** match_kwargs ):
0 commit comments