diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index e3906c09b1cc..793010e87f67 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -600,7 +600,8 @@ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, r hidden_states = posterior.latent_dist.sample(rng) else: hidden_states = posterior.latent_dist.mode() - hidden_states = self.decode(hidden_states, return_dict=return_dict).sample + + sample = self.decode(hidden_states, return_dict=return_dict).sample if not return_dict: return (sample,)