diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
index 84b8aeb7bcde..38c7b22d08d4 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -85,7 +85,10 @@ def forward(self, clip_input, images):
 
         for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
             if has_nsfw_concept:
-                images[idx] = np.zeros(images[idx].shape)  # black image
+                if torch.is_tensor(images) or torch.is_tensor(images[0]):
+                    images[idx] = torch.zeros_like(images[idx])  # black image
+                else:
+                    images[idx] = np.zeros(images[idx].shape)  # black image
 
         if any(has_nsfw_concepts):
             logger.warning(
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index 127b1c216549..0e2c4acb5484 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -453,6 +453,20 @@ def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
 
+    def test_img2img_safety_checker_works(self):
+        sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 20
+        # make sure the safety checker is activated
+        inputs["prompt"] = "naked, sex, porn"
+        out = sd_pipe(**inputs)
+
+        assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
+        assert np.abs(out.images[0]).sum() < 1e-5  # should be all zeros
+
 
 @nightly
 @require_torch_gpu
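
For reference, a minimal standalone sketch of the behavior the first hunk introduces: flagged images are blacked out with torch.zeros_like when the batch is a tensor (or a list of tensors) and with np.zeros otherwise, so the container type the pipeline passes in is preserved. The helper name black_out_flagged and the toy shapes below are illustrative only, not part of the patch.

import numpy as np
import torch


def black_out_flagged(images, has_nsfw_concepts):
    # Mirrors the patched loop: zero out each flagged image while keeping
    # the container type (torch tensor vs. NumPy array) unchanged.
    for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
        if has_nsfw_concept:
            if torch.is_tensor(images) or torch.is_tensor(images[0]):
                images[idx] = torch.zeros_like(images[idx])  # black image
            else:
                images[idx] = np.zeros(images[idx].shape)  # black image
    return images


# Tensor input stays a tensor; NumPy input stays a NumPy array.
tensor_batch = black_out_flagged(torch.ones(2, 3, 8, 8), [True, False])
numpy_batch = black_out_flagged(np.ones((2, 8, 8, 3)), [True, False])
assert torch.is_tensor(tensor_batch) and tensor_batch[0].sum() == 0
assert isinstance(numpy_batch, np.ndarray) and numpy_batch[0].sum() == 0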