From bc83cb89b08934b2cc5a074651ae353b9576fc90 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 23 Jun 2025 14:25:37 +0530 Subject: [PATCH 1/4] add test for checking compile on different shapes. --- tests/models/test_modeling_common.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index e8b41ddbfd87..442a25eff737 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1961,6 +1961,22 @@ def test_compile_with_group_offloading(self): _ = model(**inputs_dict) _ = model(**inputs_dict) + def test_compile_on_different_shapes(self): + torch.fx.experimental._config.use_duck_shape = False + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model = torch.compile(model, fullgraph=True, dynamic=True) + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + torch.no_grad(), + ): + print(f"{inputs_dict.keys()=}") + out = model(**inputs_dict) + assert out is None + @slow @require_torch_2 From 934624bac1722ecb3c7be27dd6efe4dbe08d41e5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 23 Jun 2025 15:02:33 +0530 Subject: [PATCH 2/4] update --- tests/models/test_modeling_common.py | 20 +++++++++------- .../test_models_transformer_flux.py | 24 ++++++++++++------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 442a25eff737..e9022d128c7a 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -76,6 +76,7 @@ require_torch_accelerator_with_training, require_torch_gpu, require_torch_multi_accelerator, + require_torch_version_greater, run_test_in_subprocess, slow, torch_all_close, @@ -1908,6 +1909,8 @@ def test_push_to_hub_library_name(self): @is_torch_compile @slow class TorchCompileTesterMixin: + different_shapes_for_compilation = None + def setUp(self): # clean up the VRAM before each test super().setUp() @@ -1961,21 +1964,20 @@ def test_compile_with_group_offloading(self): _ = model(**inputs_dict) _ = model(**inputs_dict) + @require_torch_version_greater("2.7.1") def test_compile_on_different_shapes(self): + if self.different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") torch.fx.experimental._config.use_duck_shape = False - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict, _ = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) model = torch.compile(model, fullgraph=True, dynamic=True) - with ( - torch._inductor.utils.fresh_inductor_cache(), - torch._dynamo.config.patch(error_on_recompile=True), - torch.no_grad(), - ): - print(f"{inputs_dict.keys()=}") - out = model(**inputs_dict) - assert out is None + for height, width in self.different_shapes_for_compilation: + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): + inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**inputs_dict) @slow diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 0a55236ef1c7..4552b2e1f5cf 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -91,10 +91,20 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): @property def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 4) + + @property + def output_shape(self): + return (16, 4) + + def prepare_dummy_input(self, height=4, width=4): batch_size = 1 num_latent_channels = 4 num_image_channels = 3 - height = width = 4 sequence_length = 48 embedding_dim = 32 @@ -114,14 +124,6 @@ def dummy_input(self): "timestep": timestep, } - @property - def input_shape(self): - return (16, 4) - - @property - def output_shape(self): - return (16, 4) - def prepare_init_args_and_inputs_for_common(self): init_dict = { "patch_size": 1, @@ -173,10 +175,14 @@ def test_gradient_checkpointing_is_applied(self): class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] def prepare_init_args_and_inputs_for_common(self): return FluxTransformerTests().prepare_init_args_and_inputs_for_common() + def prepare_dummy_input(self, height, width): + return FluxTransformerTests().prepare_dummy_input(height=height, width=width) + class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel From b761e8c49b1d0bd0d47b142d5400fd651f14a919 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Jun 2025 09:12:53 +0530 Subject: [PATCH 3/4] update --- docs/source/en/optimization/fp16.md | 17 +++++++++++++++++ tests/models/test_modeling_common.py | 6 +++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md index 2e12bfadcf5c..2e1fd80d7e2d 100644 --- a/docs/source/en/optimization/fp16.md +++ b/docs/source/en/optimization/fp16.md @@ -150,6 +150,23 @@ pipeline(prompt, num_inference_steps=30).images[0] Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient. +### Compilation on shape changes + +`torch.compile()` maintains a stack of "guards" for the shapes and conditions it sees when it is triggered. When that is violated, the compiler triggers recompilation. This means that if a model was compiled on the 1024x1024 resolution, for example, it will trigger recompilation if it is called on a different resolution. + +In these cases, it's beneficial to compile with `dynamic=True`: + +```diff ++ torch.fx.experimental._config.use_duck_shape = False ++ pipeline.unet = torch.compile( + pipeline.unet, fullgraph=True, dynamic=True +) +``` + +Make sure to always use the nightly version of PyTorch for this. Specifying `use_duck_shape` to be `False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out [this comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). + +All models might not benefit from this out of the box and may require changes. Refer to [this PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the implementation of [`AuraFlowPipeline`] to benefit from compilation with `dynamic=True`. Feel free to open an issue if dynamic compilation doesn't work expected for a model inside Diffusers. + ### Regional compilation [Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks. diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index e9022d128c7a..e86c988a8cdc 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1940,14 +1940,14 @@ def test_torch_compile_recompilation_and_graph_break(self): _ = model(**inputs_dict) def test_compile_with_group_offloading(self): + if not self.model_class._supports_group_offloading: + pytest.skip("Model does not support group offloading.") + torch._dynamo.config.cache_size_limit = 10000 init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) - if not getattr(model, "_supports_group_offloading", True): - return - model.eval() # TODO: Can test for other group offloading kwargs later if needed. group_offload_kwargs = { From d1b492bef0e8ccbd0538c3cd8d3b8926ca58bd72 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 25 Jun 2025 07:47:03 +0530 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/fp16.md | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md index 2e1fd80d7e2d..9241c0e52552 100644 --- a/docs/source/en/optimization/fp16.md +++ b/docs/source/en/optimization/fp16.md @@ -150,11 +150,14 @@ pipeline(prompt, num_inference_steps=30).images[0] Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient. -### Compilation on shape changes +### Dynamic shape compilation -`torch.compile()` maintains a stack of "guards" for the shapes and conditions it sees when it is triggered. When that is violated, the compiler triggers recompilation. This means that if a model was compiled on the 1024x1024 resolution, for example, it will trigger recompilation if it is called on a different resolution. +> [!TIP] +> Make sure to always use the nightly version of PyTorch for better support. + +`torch.compile` keeps track of input shapes and conditions, and if these are different, it recompiles the model. For example, if a model is compiled on a 1024x1024 resolution image and used on an image with a different resolution, it triggers recompilation. -In these cases, it's beneficial to compile with `dynamic=True`: +To avoid recompilation, add `dynamic=True` to try and generate a more dynamic kernel to avoid recompilation when conditions change. ```diff + torch.fx.experimental._config.use_duck_shape = False @@ -163,9 +166,11 @@ In these cases, it's beneficial to compile with `dynamic=True`: ) ``` -Make sure to always use the nightly version of PyTorch for this. Specifying `use_duck_shape` to be `False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out [this comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). +Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). + +Not all models may benefit from dynamic compilation out of the box and may require changes. Refer to this [PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the [`AuraFlowPipeline`] implementation to benefit from dynamic compilation. -All models might not benefit from this out of the box and may require changes. Refer to [this PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the implementation of [`AuraFlowPipeline`] to benefit from compilation with `dynamic=True`. Feel free to open an issue if dynamic compilation doesn't work expected for a model inside Diffusers. +Feel free to open an issue if dynamic compilation doesn't work as expected for a Diffusers model. ### Regional compilation