Skip to content

Commit f0ab5e9

Browse files
[Bug fix] Fix img2img processor with safety checker (#3127)
Fix img2img processor with safety checker
1 parent d12119e commit f0ab5e9

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

src/diffusers/pipelines/stable_diffusion/safety_checker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ def forward(self, clip_input, images):
8585

8686
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
8787
if has_nsfw_concept:
88-
images[idx] = np.zeros(images[idx].shape) # black image
88+
if torch.is_tensor(images) or torch.is_tensor(images[0]):
89+
images[idx] = torch.zeros_like(images[idx]) # black image
90+
else:
91+
images[idx] = np.zeros(images[idx].shape) # black image
8992

9093
if any(has_nsfw_concepts):
9194
logger.warning(

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,20 @@ def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
453453

454454
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
455455

456+
def test_img2img_safety_checker_works(self):
457+
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
458+
sd_pipe.to(torch_device)
459+
sd_pipe.set_progress_bar_config(disable=None)
460+
461+
inputs = self.get_inputs(torch_device)
462+
inputs["num_inference_steps"] = 20
463+
# make sure the safety checker is activated
464+
inputs["prompt"] = "naked, sex, porn"
465+
out = sd_pipe(**inputs)
466+
467+
assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
468+
assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros
469+
456470

457471
@nightly
458472
@require_torch_gpu

0 commit comments

Comments
 (0)