From 2d4cfd4cc4bc95e4c8996775b851f08e498fd5c3 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Wed, 21 Sep 2022 05:13:56 +0000
Subject: [PATCH 1/3] Fix typo in docstring.

---
 .../stable_diffusion/pipeline_flax_stable_diffusion.py          | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 7f068d7183ba..7497e6dd0b28 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -30,7 +30,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        scheduler ([`FlaxSchedulerMixin`]):
+        scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
             [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
         safety_checker ([`FlaxStableDiffusionSafetyChecker`]):

From 79d6047d6be9adf1f8dd3f50a2abd46e6b2f29b3 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 20 Sep 2022 07:17:14 +0000
Subject: [PATCH 2/3] Allow dtype to be overridden on model load.

This may be a temporary solution until #567 is addressed.
---
 src/diffusers/configuration_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 2ab85ecee16d..1c5c3d7afd58 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -154,8 +154,12 @@ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], ret
         """
         config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
 
-        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+
+        # Allow dtype to be specified on initialization
+        if "dtype" in unused_kwargs:
+            init_dict["dtype"] = unused_kwargs.pop("dtype")
 
         model = cls(**init_dict)
 
         if return_unused_kwargs:

From dabfe310088d79fb581ade1d72f59c9e0b291a29 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Wed, 21 Sep 2022 05:18:05 +0000
Subject: [PATCH 3/3] Create latents in float32

The denoising loop always computes the next step in float32, so this
would fail when using `bfloat16`.
---
 .../stable_diffusion/pipeline_flax_stable_diffusion.py          | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 7497e6dd0b28..675b61266285 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -157,7 +157,7 @@ def __call__(
             self.unet.sample_size,
         )
         if latents is None:
-            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
+            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")