Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions src/diffusers/pipeline_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,27 @@ def register_modules(self, **kwargs):
from diffusers import pipelines

for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
if module is None:
register_dict = {name: (None, None)}
else:
# retrieve library
library = module.__module__.split(".")[0]

# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir

# retrieve class_name
class_name = module.__class__.__name__
# retrieve class_name
class_name = module.__class__.__name__

register_dict = {name: (library, class_name)}
register_dict = {name: (library, class_name)}

# save model index config
self.register_to_config(**register_dict)
Expand Down Expand Up @@ -320,6 +323,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
class_name = (
config_dict["_class_name"]
if config_dict["_class_name"].startswith("Flax")
else "Flax" + config_dict["_class_name"]
)
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

# some modules can be passed directly to the init
Expand All @@ -342,6 +350,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True

# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
Expand All @@ -362,6 +371,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
Expand All @@ -372,25 +387,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
if from_pt:
class_obj = import_flax_or_no_model(pipeline_module, class_name)
else:
class_obj = getattr(pipeline_module, class_name)
class_obj = import_flax_or_no_model(pipeline_module, class_name)

importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
if from_pt:
class_obj = import_flax_or_no_model(library, class_name)
else:
class_obj = getattr(library, class_name)
class_obj = import_flax_or_no_model(library, class_name)

importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

if loaded_sub_model is None:
if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from ...utils import logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
Expand Down Expand Up @@ -60,6 +64,16 @@ def __init__(
super().__init__()
self.dtype = dtype

if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand Down Expand Up @@ -265,10 +279,23 @@ def __call__(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)

safety_params = params["safety_checker"]
images = (images * 255).round().astype("uint8")
images = np.asarray(images).reshape(-1, height, width, 3)
images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)
if self.safety_checker is not None:
safety_params = params["safety_checker"]
images_uint8_casted = (images * 255).round().astype("uint8")
num_devices, batch_size = images.shape[:2]

images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
images = np.asarray(images)

# block images
if any(has_nsfw_concept):
for i, is_nsfw in enumerate(has_nsfw_concept):
images[i] = np.asarray(images_uint8_casted[i])

images = images.reshape(num_devices, batch_size, height, width, 3)
else:
has_nsfw_concept = False

if not return_dict:
return (images, has_nsfw_concept)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(

if safety_checker is None:
logger.warn(
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(

if safety_checker is None:
logger.warn(
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(

if safety_checker is None:
logger.warn(
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
Expand Down
100 changes: 99 additions & 1 deletion tests/test_pipelines_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

if is_flax_available():
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
Expand All @@ -34,7 +35,7 @@
class FlaxPipelineTests(unittest.TestCase):
def test_dummy_all_tpus(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe"
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)

prompt = (
Expand All @@ -57,6 +58,103 @@ def test_dummy_all_tpus(self):
prompt_ids = shard(prompt_ids)

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

assert images.shape == (8, 1, 64, 64, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 1e-2

images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

assert len(images_pil) == 8

def test_stable_diffusion_v1_4(self):
    """Run the fp32 Flax Stable Diffusion v1-4 checkpoint under ``pmap`` without a safety checker.

    The pipeline is loaded from ``CompVis/stable-diffusion-v1-4`` (revision ``flax``) with
    ``safety_checker=None``; one copy of the prompt is generated per available device for
    50 inference steps.

    The test does not write anything to disk: the previous version saved the images to a
    developer-specific absolute path, which fails on every other machine.
    """
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
    )

    prompt = (
        "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
        " field, close up, split lighting, cinematic"
    )

    prng_seed = jax.random.PRNGKey(0)
    num_inference_steps = 50

    # one sample per device; every per-device quantity below is derived from this value
    num_samples = jax.device_count()
    prompt = num_samples * [prompt]
    prompt_ids = pipeline.prepare_inputs(prompt)

    # `num_inference_steps` (positional index 3) is passed as a static, broadcasted argument
    p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

    # shard inputs and rng
    params = replicate(params)
    # split into one key per device instead of a hard-coded 8, so the leading axis
    # matches the sharded inputs on hosts with a different device count
    prng_seed = jax.random.split(prng_seed, num_samples)
    prompt_ids = shard(prompt_ids)

    images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

    assert images.shape == (num_samples, 1, 512, 512, 3)

    # NOTE(review): the reference sums below appear to have been recorded on an 8-device
    # run (the original test hard-coded 8 everywhere), so they are only compared when
    # exactly eight devices are present -- confirm the values if the hardware changes.
    if num_samples == 8:
        assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3
        assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 1e-2

    # collapse the (device, batch) axes before converting to PIL, as in `test_dummy_all_tpus`
    images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

    assert len(images_pil) == num_samples

def test_stable_diffusion_v1_4_bfloat_16(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no safety + outer pmap

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None
)

prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

assert images.shape == (8, 1, 512, 512, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2

def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with safety + inner pmap

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
)

prompt = (
"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
" field, close up, split lighting, cinematic"
)

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to replace 8 with jax.device_count(), no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 agree

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this would indeed be a good idea (in case someone wants to open a PR for it, please feel free)

prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images

assert images.shape == (8, 1, 512, 512, 3)
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2