From b516b0f4e125c230c311abc8f7c1c808c08c8f41 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 12 Aug 2024 17:02:08 +0530 Subject: [PATCH 1/2] feat: support sharding for flux. --- src/diffusers/models/transformers/transformer_flux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 3983606e46ac..8e6d4356ed01 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -250,6 +250,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig """ _supports_gradient_checkpointing = True + _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] @register_to_config def __init__( From 704c31ea850e3d9789b009ecd7e9b45b41ffe5ba Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 12 Aug 2024 17:09:08 +0530 Subject: [PATCH 2/2] tests --- tests/models/transformers/test_models_transformer_flux.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index d1c85537b00b..bda37621c27d 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -29,6 +29,8 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel main_input_name = "hidden_states" + # We override the items here because the transformer under consideration is small. + model_split_percents = [0.7, 0.6, 0.6] @property def dummy_input(self):