From c2ae2d7d27094790dc8d5b694fe3604c17a9ccda Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 20 Sep 2022 19:55:46 +0200 Subject: [PATCH 1/4] use FlaxPreTrainedModel for flax safety module --- .../stable_diffusion/safety_checker_flax.py | 92 +++++++++++++------ 1 file changed, 64 insertions(+), 28 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index de84b793a176..6949470f8624 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -1,4 +1,5 @@ import warnings +from typing import Optional, Tuple import numpy as np @@ -6,13 +7,9 @@ import jax.numpy as jnp from flax import linen as nn from flax.core.frozen_dict import FrozenDict -from flax.struct import field -from transformers import CLIPVisionConfig +from transformers import CLIPConfig, FlaxPreTrainedModel from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule -from ...configuration_utils import ConfigMixin, flax_register_to_config -from ...modeling_flax_utils import FlaxModelMixin - def jax_cosine_distance(emb_1, emb_2, eps=1e-12): norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T @@ -20,34 +17,17 @@ def jax_cosine_distance(emb_1, emb_2, eps=1e-12): return jnp.matmul(norm_emb_1, norm_emb_2.T) -@flax_register_to_config -class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin): - projection_dim: int = 768 - # CLIPVisionConfig fields - vision_config: dict = field(default_factory=dict) +class FlaxStableDiffusionSafetyCheckerModule(nn.Module): + config: CLIPConfig dtype: jnp.dtype = jnp.float32 - def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: - # init input tensor - input_shape = ( - 1, - self.vision_config["image_size"], - self.vision_config["image_size"], - self.vision_config["num_channels"], - ) - pixel_values = jax.random.normal(rng, input_shape) - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - return self.init(rngs, pixel_values)["params"] - def setup(self): - clip_vision_config = CLIPVisionConfig(**self.vision_config) - self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype) - self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype) + self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) + self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False) - self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim)) + self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) self.special_care_embeds = self.param( - "special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim) + "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) ) self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) @@ -109,3 +89,59 @@ def filtered_with_scores(self, special_cos_dist, cos_dist, images): ) return images, has_nsfw_concepts + + +class StableDiffusionSafetyCheckerModel(FlaxPreTrainedModel): + config_class = CLIPConfig + main_input_name = "clip_input" + module_class = FlaxStableDiffusionSafetyCheckerModule + + def __init__( + self, + config: CLIPConfig, + input_shape: Optional[Tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if input_shape is None: + input_shape = (1, 224, 224, 3) + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensor + clip_input = jax.random.normal(rng, input_shape) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, clip_input)["params"] + + return random_params + + def __call__( + self, + clip_input, + params: dict = None, + ): + clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) + + return self.module.apply( + {"params": params or self.params}, + jnp.array(clip_input, dtype=jnp.float32), + rngs={}, + ) + + def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None): + def _filtered_with_scores(module, special_cos_dist, cos_dist, images): + return module.filtered_with_scores(special_cos_dist, cos_dist, images) + + return self.module.apply( + {"params": params or self.params}, + special_cos_dist, + cos_dist, + images, + method=_filtered_with_scores, + ) From b300c1bbdcfd9e07cf6779b35ee6af255dd10526 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 20 Sep 2022 19:57:20 +0200 Subject: [PATCH 2/4] fix name --- src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 6949470f8624..8dd60a907f22 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -91,7 +91,7 @@ def filtered_with_scores(self, special_cos_dist, cos_dist, images): return images, has_nsfw_concepts -class StableDiffusionSafetyCheckerModel(FlaxPreTrainedModel): +class StableDiffusionSafetyChecker(FlaxPreTrainedModel): config_class = CLIPConfig main_input_name = "clip_input" module_class = FlaxStableDiffusionSafetyCheckerModule From dfe1b35ecaa9dea4524603a6607accba9b2bf941 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 20 Sep 2022 19:58:13 +0200 Subject: [PATCH 3/4] fix one more --- src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 8dd60a907f22..0d520af6da3d 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -91,7 +91,7 @@ def filtered_with_scores(self, special_cos_dist, cos_dist, images): return images, has_nsfw_concepts -class StableDiffusionSafetyChecker(FlaxPreTrainedModel): +class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): config_class = CLIPConfig main_input_name = "clip_input" module_class = FlaxStableDiffusionSafetyCheckerModule From deff4289bc2ac185f0a961314569cad14acdaf02 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 20 Sep 2022 20:07:47 +0200 Subject: [PATCH 4/4] Apply suggestions from code review --- src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 0d520af6da3d..b3cd8eef02fa 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -23,7 +23,7 @@ class FlaxStableDiffusionSafetyCheckerModule(nn.Module): def setup(self): self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) - self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False) + self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) self.special_care_embeds = self.param(