diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 49cefcd8a142..49b6330574a7 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -619,6 +619,13 @@ def __init__( self.gradient_checkpointing = False def forward(self, x, feat_cache=None, feat_idx=[0]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + return self._gradient_checkpointing_func(self._decode, x, feat_cache, feat_idx) + else: + return self._decode(x, feat_cache, feat_idx) + + def _decode(self, x, in_cache=None, feat_idx=[0]): + feat_cache = in_cache.copy() if in_cache is not None else None ## conv1 if feat_cache is not None: idx = feat_idx[0] @@ -653,7 +660,8 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 else: x = self.conv_out(x) - return x + feat_idx[0] = 0 + return x, feat_cache class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): @@ -665,7 +673,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): for all models (such as downloading or saving). 
""" - _supports_gradient_checkpointing = False + _supports_gradient_checkpointing = True @register_to_config def __init__( @@ -884,9 +892,13 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): for i in range(num_frame): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out, self._feat_map = self.decoder( + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx + ) else: - out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out_, self._feat_map = self.decoder( + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx + ) out = torch.cat([out, out_], 2) out = torch.clamp(out, min=-1.0, max=1.0) diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py index c0af4f5834b7..43f710e9d012 100644 --- a/tests/models/autoencoders/test_models_autoencoder_wan.py +++ b/tests/models/autoencoders/test_models_autoencoder_wan.py @@ -140,7 +140,12 @@ def test_enable_disable_slicing(self): - @unittest.skip("Gradient checkpointing has not been implemented yet") def test_gradient_checkpointing_is_applied(self): - pass + expected_set = { + "WanDecoder3d", + "WanEncoder3d", + "WanMidBlock", + "WanUpBlock", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) @unittest.skip("Test not supported") def test_forward_with_norm_groups(self):