diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 453690c1c901..5087bd0094a5 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1744,6 +1744,10 @@ def test_push_to_hub_library_name(self):
         delete_repo(self.repo_id, token=TOKEN)
 
 
+@require_torch_gpu
+@require_torch_2
+@is_torch_compile
+@slow
 class TorchCompileTesterMixin:
     def setUp(self):
         # clean up the VRAM before each test
@@ -1759,12 +1763,7 @@ def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)
 
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
     def test_torch_compile_recompilation_and_graph_break(self):
-        torch.compiler.reset()
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
 
@@ -1778,6 +1777,31 @@ def test_torch_compile_recompilation_and_graph_break(self):
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)
 
+    def test_compile_with_group_offloading(self):
+        torch._dynamo.config.cache_size_limit = 10000
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        model.eval()
+        # TODO: Can test for other group offloading kwargs later if needed.
+        group_offload_kwargs = {
+            "onload_device": "cuda",
+            "offload_device": "cpu",
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+            "non_blocking": True,
+        }
+        model.enable_group_offload(**group_offload_kwargs)
+        model.compile()
+        with torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
+
 
 @slow
 @require_torch_2
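
For reference, here is a minimal sketch of the user-facing pattern the new test exercises: enabling group offloading on a model and then compiling it, with both calls (`enable_group_offload`, `model.compile()`) and the offload kwargs taken directly from the test above. The checkpoint name, `UNet2DConditionModel` choice, and dummy input shapes are illustrative assumptions, not part of this PR; any diffusers model whose `_supports_group_offloading` flag is not False should behave the same way.

import torch
from diffusers import UNet2DConditionModel

# Illustrative checkpoint (assumption): any offloading-capable
# diffusers model works here.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet.eval()

# Same kwargs the new test uses: stage one block at a time between CPU
# and GPU, overlapping transfers with compute via a CUDA stream.
unet.enable_group_offload(
    onload_device="cuda",
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    non_blocking=True,
)
unet.compile()  # in-place torch.compile via nn.Module.compile()

# Dummy SD 1.5-shaped inputs (assumption) on the onload device.
sample = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
timestep = torch.tensor([10], device="cuda")
encoder_hidden_states = torch.randn(1, 77, 768, dtype=torch.float16, device="cuda")

with torch.no_grad():
    _ = unet(sample, timestep, encoder_hidden_states).sample
    # A second call, as in the test, checks that offloading hooks do not
    # trigger graph breaks or excessive recompilation.
    _ = unet(sample, timestep, encoder_hidden_states).sample

As in the test, raising torch._dynamo.config.cache_size_limit may be needed if the offloading hooks cause dynamo to recompile per block.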