Flax safety checker #825
Changes from all commits
3c838a2
a444010
3ca68c4
9d84107
750e20f
dcd27fd
a0680ed
d65d1a2
4239bff
866600b
86cb5a1
1cd8bb5
fe2817b
b255e9a
2533c50
@@ -1,8 +1,14 @@
from functools import partial
from typing import Dict, List, Optional, Union

import numpy as np

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
@@ -77,60 +83,44 @@ def prepare_inputs(self, prompt: Union[str, List[str]]): | |
| ) | ||
| return text_input.input_ids | ||
|
|
||
| def __call__( | ||
| def _get_safety_scores(self, features, params): | ||
| special_cos_dist, cos_dist = self.safety_checker(features, params) | ||
| return (special_cos_dist, cos_dist) | ||
|
|
||
| def _run_safety_checker(self, images, safety_model_params, jit=False): | ||
| # safety_model_params should already be replicated when jit is True | ||
| pil_images = [Image.fromarray(image) for image in images] | ||
| features = self.feature_extractor(pil_images, return_tensors="np").pixel_values | ||
|
|
||
| if jit: | ||
| features = shard(features) | ||
| special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params) | ||
| special_cos_dist = unshard(special_cos_dist) | ||
| cos_dist = unshard(cos_dist) | ||
| safety_model_params = unreplicate(safety_model_params) | ||
|
Contributor: why do we need to do …
Member (Author): Because if we are using …
        else:
            special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)

        images, has_nsfw = self.safety_checker.filtered_with_scores(
            special_cos_dist,
            cos_dist,
            images,
            safety_model_params,
        )
        return images, has_nsfw

    def _generate(
        self,
        prompt_ids: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.PRNGKey,
-        num_inference_steps: Optional[int] = 50,
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
-        guidance_scale: Optional[float] = 7.5,
        num_inference_steps: int = 50,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
        latents: Optional[jnp.array] = None,
        return_dict: bool = True,
        debug: bool = False,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
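Editor's aside, not part of the diff: the `guidance_scale` description above corresponds to the usual classifier-free guidance combination of the unconditional and text-conditioned UNet outputs. The sketch below is a minimal illustration; the names `noise_pred_uncond` and `noise_pred_text` are assumptions, not identifiers from this PR.

```python
# Hedged illustration of classifier-free guidance as described in the docstring above.
# Both inputs are assumed to be UNet noise predictions of identical shape.
import jax.numpy as jnp

def apply_guidance(noise_pred_uncond: jnp.ndarray,
                   noise_pred_text: jnp.ndarray,
                   guidance_scale: float) -> jnp.ndarray:
    # guidance_scale plays the role of `w`; values > 1 push the result toward the text branch
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```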
@@ -203,21 +193,106 @@ def loop_body(step, args):

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        # TODO: check when flax vae gets merged into main
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

-        # image = jnp.asarray(image).transpose(0, 2, 3, 1)
-        # run safety checker
-        # TODO: check when flax safety checker gets merged into main
-        # safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-        # image, has_nsfw_concept = self.safety_checker(
-        #     images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
-        # )
-        has_nsfw_concept = False

    def __call__(
        self,
        prompt_ids: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.PRNGKey,
        num_inference_steps: int = 50,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
        latents: jnp.array = None,
        return_dict: bool = True,
        jit: bool = False,
        debug: bool = False,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
                paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        if jit:
            images = _p_generate(
                self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
            )
        else:
            images = self._generate(
                prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
            )

        safety_params = params["safety_checker"]
        images = (images * 255).round().astype("uint8")
        images = np.asarray(images).reshape(-1, height, width, 3)
        images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)

        if not return_dict:
-            return (image, has_nsfw_concept)
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
    pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
):
    return pipe._generate(
        prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_safety_scores(pipe, features, params):
    return pipe._get_safety_scores(features, params)
Comment on lines +279 to +291
Contributor: (nit) maybe have this as …
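Editor's usage sketch, not part of the diff: with `jit=True`, the pmapped helpers above expect `params` to be replicated across devices and `prompt_ids`/`prng_seed` to carry a leading device axis. The pipeline construction line and the prompt are assumptions for illustration; `prepare_inputs` and the `jit` keyword come from this PR.

```python
# Hedged sketch of the jit=True path; pipeline/params construction is assumed, not shown in this PR.
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard

# pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(...)  # assumed entry point

num_devices = jax.local_device_count()
prompt = ["a photo of an astronaut riding a horse"] * num_devices
prompt_ids = pipeline.prepare_inputs(prompt)      # (num_devices, seq_len)

p_params = replicate(params)                      # leading device axis on every parameter leaf
prompt_ids = shard(prompt_ids)                    # (num_devices, 1, seq_len)
prng_seed = jax.random.split(jax.random.PRNGKey(0), num_devices)

output = pipeline(prompt_ids, p_params, prng_seed, jit=True)
images, nsfw_flags = output.images, output.nsfw_content_detected
```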
-        return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


def unshard(x: jnp.ndarray):
Contributor: Let's maybe also make it private.
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    num_devices, batch_size = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(num_devices * batch_size, *rest)
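Editor's note, not part of the diff: `shard` splits a global batch into `(num_devices, batch_per_device, ...)` for `pmap`, and `unshard` merges those two leading axes back, which is what `_run_safety_checker` relies on. Below is a small hedged round-trip example; the dummy array shape is an assumption chosen to resemble CLIP pixel values.

```python
# Hedged sketch of the shard -> unshard round trip used in the jit path above.
import jax
import numpy as np
from flax.training.common_utils import shard

num_devices = jax.local_device_count()
features = np.zeros((num_devices * 2, 3, 224, 224), dtype=np.float32)  # dummy CLIP pixel values

sharded = shard(features)   # (num_devices, 2, 3, 224, 224): one slice per device
merged = unshard(sharded)   # (num_devices * 2, 3, 224, 224): original layout restored
assert merged.shape == features.shape
```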