Skip to content
Merged
2 changes: 1 addition & 1 deletion src/diffusers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
```"""
return self._cast_floating_to(params, jnp.float16, mask)

def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/unet_2d_condition_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0

def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/vae_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def setup(self):
dtype=self.dtype,
)

def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
47 changes: 46 additions & 1 deletion src/diffusers/pipeline_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import importlib
import inspect
import os
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -475,6 +475,51 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model = pipeline_class(**init_kwargs, dtype=dtype)
return model, params

@staticmethod
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - set(["self"])
return expected_modules, optional_parameters

@property
def components(self) -> Dict[str, Any]:
r"""
The `self.components` property can be useful to run different pipelines with the same weights and
configurations to not have to re-allocate memory.
Examples:
```py
>>> from diffusers import (
... FlaxStableDiffusionPipeline,
... FlaxStableDiffusionImg2ImgPipeline,
... )
>>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16
... )
>>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
```
Returns:
A dictionary containing all the modules needed to initialize the pipeline.
"""
expected_modules, optional_parameters = self._get_signature_keys(self)
components = {
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}

if set(components.keys()) != expected_modules:
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
f" {expected_modules} to be defined, but {components} are defined."
)

return components

@staticmethod
def numpy_to_pil(images):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def components(self) -> Dict[str, Any]:
```
Returns:
A dictionaly containing all the modules needed to initialize the pipeline.
A dictionary containing all the modules needed to initialize the pipeline.
"""
expected_modules, optional_parameters = self._get_signature_keys(self)
components = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,14 @@ def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: float = 7.5,
prng_seed: jax.random.KeyArray,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
neg_prompt_ids: jnp.array = None,
neg_prompt_ids: Optional[jnp.array] = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

Expand Down Expand Up @@ -281,15 +277,15 @@ def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
prng_seed: jax.random.KeyArray,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
neg_prompt_ids: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
neg_prompt_ids: jnp.array = None,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down
Loading