
Commit c727a6a

Finally fix the image-based SD tests (#509)

* Finally fix the image-based SD tests
* Remove autocast
* Remove autocast in image tests

1 parent f73ca90 commit c727a6a

File tree

2 files changed: +141 -82 lines changed

src/diffusers/testing_utils.py

Lines changed: 33 additions & 0 deletions
@@ -2,9 +2,13 @@
 import random
 import unittest
 from distutils.util import strtobool
+from typing import Union
 
 import torch
 
+import PIL.Image
+import PIL.ImageOps
+import requests
 from packaging import version
 
 
@@ -59,3 +63,32 @@ def slow(test_case):
     """
     return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
+    """
+    Loads `image` to a PIL Image.
+
+    Args:
+        image (`str` or `PIL.Image.Image`): The image to convert to the PIL Image format.
+    Returns:
+        `PIL.Image.Image`: A PIL Image.
+    """
+    if isinstance(image, str):
+        if image.startswith("http://") or image.startswith("https://"):
+            image = PIL.Image.open(requests.get(image, stream=True).raw)
+        elif os.path.isfile(image):
+            image = PIL.Image.open(image)
+        else:
+            raise ValueError(
+                f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
+            )
+    elif isinstance(image, PIL.Image.Image):
+        image = image
+    else:
+        raise ValueError(
+            "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
+        )
+    image = PIL.ImageOps.exif_transpose(image)
+    image = image.convert("RGB")
+    return image
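Note: a minimal sketch of how the new `load_image` helper is meant to be used, fetching one of the fixture images referenced in this commit (the resize call just mirrors what the tests below do):

from diffusers.testing_utils import load_image

# Accepts an http(s) URL, a local file path, or an existing PIL image;
# always returns an RGB PIL.Image.Image with EXIF orientation applied.
init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))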

tests/test_pipelines.py

Lines changed: 108 additions & 82 deletions
@@ -22,7 +22,6 @@
 import torch
 
 import PIL
-from datasets import load_dataset
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -47,7 +46,7 @@
     VQModel,
 )
 from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.testing_utils import floats_tensor, slow, torch_device
+from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -168,7 +167,7 @@ def dummy_text_encoder(self):
     @property
     def dummy_safety_checker(self):
         def check(images, *args, **kwargs):
-            return images, False
+            return images, [False] * len(images)
 
         return check
 
@@ -708,6 +707,13 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
+    @property
+    def dummy_safety_checker(self):
+        def check(images, *args, **kwargs):
+            return images, [False] * len(images)
+
+        return check
+
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
         model = UNet2DModel(
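Note on the `dummy_safety_checker` change in the two hunks above: the real safety checker reports one nsfw flag per image, so a stub returning a single `False` does not match that shape. A minimal sketch of the contract the stub now satisfies (the consuming loop is illustrative, not the pipeline's exact code):

from typing import List, Tuple

def dummy_check(images: List, *args, **kwargs) -> Tuple[List, List[bool]]:
    # Mirror the real safety checker's output shape:
    # the second return value holds one "is NSFW" flag per input image.
    return images, [False] * len(images)

images, nsfw_flags = dummy_check(["image_0", "image_1"])
for image, flagged in zip(images, nsfw_flags):
    assert flagged is False  # nothing gets blacked out in tests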
@@ -1139,144 +1145,164 @@ def test_stable_diffusion_memory_chunking(self):
 
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
-    def test_stable_diffusion_img2img_pipeline(self):
-        ds = load_dataset(
-            "imagefolder",
-            data_files={
-                "input": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/img2img/sketch-mountains-input.jpg"
-                ],
-                "output": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/img2img/fantasy_landscape.png"
-                ],
-            },
+    def test_stable_diffusion_text2img_pipeline(self):
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/text2img/astronaut_riding_a_horse.png"
         )
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id,
+            safety_checker=self.dummy_safety_checker,
+            use_auth_token=True,
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        prompt = "astronaut riding a horse"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np")
+        image = output.images[0]
 
-        init_image = ds["input"]["image"][0].resize((768, 512))
-        output_image = ds["output"]["image"][0].resize((768, 512))
+        assert image.shape == (512, 512, 3)
+        assert np.abs(expected_image - image).max() < 1e-2
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_img2img_pipeline(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/fantasy_landscape.png"
+        )
+        init_image = init_image.resize((768, 512))
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
 
         model_id = "CompVis/stable-diffusion-v1-4"
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
+            safety_checker=self.dummy_safety_checker,
             use_auth_token=True,
         )
         pipe.to(torch_device)
-        pipe.enable_attention_slicing()
         pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
 
         prompt = "A fantasy landscape, trending on artstation"
 
         generator = torch.Generator(device=torch_device).manual_seed(0)
-        with torch.autocast("cuda"):
-            output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)
+        output = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+        )
         image = output.images[0]
 
-        expected_array = np.array(output_image) / 255.0
-        sampled_array = np.array(image) / 255.0
+        Image.fromarray((image * 255).round().astype("uint8")).save("fantasy_landscape.png")
 
-        assert sampled_array.shape == (512, 768, 3)
-        assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
+        assert image.shape == (512, 768, 3)
+        assert np.abs(expected_image - image).max() < 1e-2
 
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_img2img_pipeline_k_lms(self):
-        ds = load_dataset(
-            "imagefolder",
-            data_files={
-                "input": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/img2img/sketch-mountains-input.jpg"
-                ],
-                "output": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/img2img/fantasy_landscape_k_lms.png"
-                ],
-            },
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
         )
-
-        init_image = ds["input"]["image"][0].resize((768, 512))
-        output_image = ds["output"]["image"][0].resize((768, 512))
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/fantasy_landscape_k_lms.png"
+        )
+        init_image = init_image.resize((768, 512))
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
 
         lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
 
         model_id = "CompVis/stable-diffusion-v1-4"
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
             scheduler=lms,
+            safety_checker=self.dummy_safety_checker,
             use_auth_token=True,
         )
-        pipe.enable_attention_slicing()
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
 
         prompt = "A fantasy landscape, trending on artstation"
 
         generator = torch.Generator(device=torch_device).manual_seed(0)
-        with torch.autocast("cuda"):
-            output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)
+        output = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+        )
         image = output.images[0]
 
-        expected_array = np.array(output_image) / 255.0
-        sampled_array = np.array(image) / 255.0
+        Image.fromarray((image * 255).round().astype("uint8")).save("fantasy_landscape_k_lms.png")
 
-        assert sampled_array.shape == (512, 768, 3)
-        assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
+        assert image.shape == (512, 768, 3)
+        assert np.abs(expected_image - image).max() < 1e-2
 
     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_inpaint_pipeline(self):
-        ds = load_dataset(
-            "imagefolder",
-            data_files={
-                "input": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/in_paint/overture-creations-5sI6fQgYIuo.png"
-                ],
-                "mask": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
-                ],
-                "output": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/in_paint/red_cat_sitting_on_a_parking_bench.png"
-                ],
-            },
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo.png"
         )
-
-        init_image = ds["input"]["image"][0].resize((768, 512))
-        mask_image = ds["mask"]["image"][0].resize((768, 512))
-        output_image = ds["output"]["image"][0].resize((768, 512))
+        mask_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/red_cat_sitting_on_a_park_bench.png"
+        )
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
 
         model_id = "CompVis/stable-diffusion-v1-4"
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
+            safety_checker=self.dummy_safety_checker,
             use_auth_token=True,
         )
         pipe.to(torch_device)
-        pipe.enable_attention_slicing()
         pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
 
-        prompt = "A red cat sitting on a parking bench"
+        prompt = "A red cat sitting on a park bench"
 
         generator = torch.Generator(device=torch_device).manual_seed(0)
-        with torch.autocast("cuda"):
-            output = pipe(
-                prompt=prompt,
-                init_image=init_image,
-                mask_image=mask_image,
-                strength=0.75,
-                guidance_scale=7.5,
-                generator=generator,
-            )
+        output = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            mask_image=mask_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+        )
         image = output.images[0]
 
-        expected_array = np.array(output_image) / 255.0
-        sampled_array = np.array(image) / 255.0
+        Image.fromarray((image * 255).round().astype("uint8")).save("red_cat_sitting_on_a_park_bench.png")
 
-        assert sampled_array.shape == (512, 768, 3)
-        assert np.max(np.abs(sampled_array - expected_array)) < 1e-3
+        assert image.shape == (512, 512, 3)
+        assert np.abs(expected_image - image).max() < 1e-2
 
     @slow
     def test_stable_diffusion_onnx(self):
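Note on the comparison pattern these tests now share: with `torch.autocast` removed, runs are kept reproducible by seeding a device-specific `torch.Generator`, and `output_type="np"` returns the image as a float array in [0, 1], so the reference PNG can be compared with a max-absolute-difference check (the tolerance is loosened from 1e-4 to 1e-2 in this commit). A minimal standalone sketch of that pattern, assuming GPU access and a valid auth token:

import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from diffusers.testing_utils import load_image

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True
)
pipe.to("cuda")

# A fixed seed on the target device makes the sampled image reproducible.
generator = torch.Generator(device="cuda").manual_seed(0)
output = pipe(prompt="astronaut riding a horse", generator=generator, output_type="np")
image = output.images[0]  # float32 array of shape (512, 512, 3), values in [0, 1]

# Scale the reference PNG to [0, 1] and compare element-wise.
expected = np.array(
    load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        "/text2img/astronaut_riding_a_horse.png"
    ),
    dtype=np.float32,
) / 255.0
assert np.abs(expected - image).max() < 1e-2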
