diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 80055c1a10f8..94a186d1c06a 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -199,7 +199,7 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + model = self.model_class(**init_dict).eval() model.to(torch_device) out = model(**inputs_dict).sample