Skip to content

Commit 39b87b1

Browse files
authored
feat: allow flux transformer to be sharded during inference (#9159)
- feat: support sharding for flux.
- tests
1 parent 3e46043 commit 39b87b1

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

src/diffusers/models/transformers/transformer_flux.py

+1
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
251251
"""
252252

253253
_supports_gradient_checkpointing = True
254+
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
254255

255256
@register_to_config
256257
def __init__(

tests/models/transformers/test_models_transformer_flux.py

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
3030
model_class = FluxTransformer2DModel
3131
main_input_name = "hidden_states"
32+
# We override the items here because the transformer under consideration is small.
33+
model_split_percents = [0.7, 0.6, 0.6]
3234

3335
@property
3436
def dummy_input(self):

0 commit comments

Comments (0)