Commit 0b42b07

[Onnx] support half-precision and fix bugs for onnx pipelines (#932)
* [Onnx] support half-precision and fix bugs for onnx pipelines
* Update convert_stable_diffusion_checkpoint_to_onnx.py
* style
* fix has_nsfw_concept
* Update convert_stable_diffusion_checkpoint_to_onnx.py
* fix style
1 parent 3d02c92 commit 0b42b07

File tree

4 files changed: +112 −30 lines

- scripts/convert_stable_diffusion_checkpoint_to_onnx.py
- src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
- src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
- src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

scripts/convert_stable_diffusion_checkpoint_to_onnx.py

Lines changed: 22 additions & 11 deletions
```diff
@@ -69,8 +69,15 @@ def onnx_export(
 
 
 @torch.no_grad()
-def convert_models(model_path: str, output_path: str, opset: int):
-    pipeline = StableDiffusionPipeline.from_pretrained(model_path)
+def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
+    dtype = torch.float16 if fp16 else torch.float32
+    if fp16 and torch.cuda.is_available():
+        device = "cuda"
+    elif fp16 and not torch.cuda.is_available():
+        raise ValueError("`float16` model export is only supported on GPUs with CUDA")
+    else:
+        device = "cpu"
+    pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
     output_path = Path(output_path)
 
     # TEXT ENCODER
@@ -84,7 +91,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
     onnx_export(
         pipeline.text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
-        model_args=(text_input.input_ids.to(torch.int32)),
+        model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)),
         output_path=output_path / "text_encoder" / "model.onnx",
         ordered_input_names=["input_ids"],
         output_names=["last_hidden_state", "pooler_output"],
@@ -100,9 +107,9 @@ def convert_models(model_path: str, output_path: str, opset: int):
     onnx_export(
         pipeline.unet,
         model_args=(
-            torch.randn(2, pipeline.unet.in_channels, 64, 64),
-            torch.LongTensor([0, 1]),
-            torch.randn(2, 77, 768),
+            torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
+            torch.LongTensor([0, 1]).to(device=device),
+            torch.randn(2, 77, 768).to(device=device, dtype=dtype),
             False,
         ),
         output_path=unet_path,
@@ -139,7 +146,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
     vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
     onnx_export(
         vae_encoder,
-        model_args=(torch.randn(1, 3, 512, 512), False),
+        model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
         output_path=output_path / "vae_encoder" / "model.onnx",
         ordered_input_names=["sample", "return_dict"],
         output_names=["latent_sample"],
@@ -155,7 +162,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
     vae_decoder.forward = vae_encoder.decode
     onnx_export(
         vae_decoder,
-        model_args=(torch.randn(1, 4, 64, 64), False),
+        model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
        output_path=output_path / "vae_decoder" / "model.onnx",
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
@@ -171,13 +178,16 @@ def convert_models(model_path: str, output_path: str, opset: int):
     safety_checker.forward = safety_checker.forward_onnx
     onnx_export(
         pipeline.safety_checker,
-        model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)),
+        model_args=(
+            torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
+            torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
+        ),
         output_path=output_path / "safety_checker" / "model.onnx",
         ordered_input_names=["clip_input", "images"],
         output_names=["out_images", "has_nsfw_concepts"],
         dynamic_axes={
             "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-            "images": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
         },
         opset=opset,
     )
@@ -221,7 +231,8 @@ def convert_models(model_path: str, output_path: str, opset: int):
         type=int,
         help="The version of the ONNX operator set to use.",
     )
+    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")
 
     args = parser.parse_args()
 
-    convert_models(args.model_path, args.output_path, args.opset)
+    convert_models(args.model_path, args.output_path, args.opset, args.fp16)
```
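For reference, a half-precision export might be driven like this. This is a minimal sketch, not part of the commit: the checkpoint name, output directory, and opset value are placeholders, and CUDA is required because `convert_models` raises a `ValueError` when `fp16=True` without it.

```python
# Minimal sketch of the new fp16 export path; the checkpoint name, output
# directory, and opset value below are placeholders, not values from the commit.
# Equivalent CLI:
#   python scripts/convert_stable_diffusion_checkpoint_to_onnx.py \
#       --model_path CompVis/stable-diffusion-v1-4 --output_path ./sd-onnx-fp16 --opset 14 --fp16
from convert_stable_diffusion_checkpoint_to_onnx import convert_models  # assumes the script directory is on sys.path

convert_models(
    model_path="CompVis/stable-diffusion-v1-4",  # hub id or local checkpoint path
    output_path="./sd-onnx-fp16",                # where the exported ONNX models are written
    opset=14,                                    # placeholder opset version
    fp16=True,                                   # half precision; raises ValueError without CUDA
)
```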

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 30 additions & 5 deletions
```diff
@@ -55,7 +55,9 @@ def __call__(
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
+        generator: Optional[np.random.RandomState] = None,
         latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -81,6 +83,9 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
+        if generator is None:
+            generator = np.random
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -98,6 +103,7 @@ def __call__(
             )
         text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+        text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -133,16 +139,18 @@ def __call__(
                 return_tensors="np",
             )
             uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+            uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
 
         # get the initial random noise unless the user supplied it
-        latents_shape = (batch_size, 4, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
+        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
         if latents is None:
-            latents = np.random.randn(*latents_shape).astype(np.float32)
+            latents = generator.randn(*latents_shape).astype(latents_dtype)
         elif latents.shape != latents_shape:
             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
 
@@ -185,13 +193,30 @@ def __call__(
                 callback(i, t, latents)
 
         latents = 1 / 0.18215 * latents
-        image = self.vae_decoder(latent_sample=latents)[0]
+        # image = self.vae_decoder(latent_sample=latents)[0]
+        # it seems like there is a strange result when using the half-precision vae decoder if batch size > 1
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )
 
         image = np.clip(image / 2 + 0.5, 0, 1)
         image = image.transpose((0, 2, 3, 1))
 
-        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-        image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # the safety checker throws an error when run with batch size > 1
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
+        else:
+            has_nsfw_concept = None
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
```
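With these changes, the text-to-image ONNX pipeline accepts a `generator` (a `np.random.RandomState`, since the pipeline is numpy-based) and a `num_images_per_prompt` batch multiplier. A usage sketch; the local model path and execution provider are assumptions for illustration:

```python
import numpy as np
from diffusers import OnnxStableDiffusionPipeline

# Load an exported pipeline; the path and provider are placeholders.
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./sd-onnx-fp16", provider="CUDAExecutionProvider"
)

# The ONNX pipelines take a np.random.RandomState, not a torch.Generator.
generator = np.random.RandomState(0)

result = pipe(
    "a photo of an astronaut riding a horse",
    num_images_per_prompt=2,  # new: yields batch_size * 2 images
    generator=generator,      # new: seeds the initial latents for deterministic output
)
images = result.images
```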

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

Lines changed: 28 additions & 6 deletions
```diff
@@ -121,6 +121,7 @@ def __call__(
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
+        generator: Optional[np.random.RandomState] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -159,6 +160,8 @@ def __call__(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`np.random.RandomState`, *optional*):
+                A np.random.RandomState to make generation deterministic.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -197,6 +200,9 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
+        if generator is None:
+            generator = np.random
+
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
@@ -239,7 +245,7 @@ def __call__(
                 f" {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
-            uncond_tokens = [negative_prompt]
+            uncond_tokens = [negative_prompt] * batch_size
         elif batch_size != len(negative_prompt):
             raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
         else:
@@ -257,13 +263,15 @@ def __call__(
             uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]
 
             # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
+            uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
 
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.astype(latents_dtype)
         # encode the init image into latents and scale the latents
         init_latents = self.vae_encoder(sample=init_image)[0]
         init_latents = 0.18215 * init_latents
@@ -297,7 +305,7 @@ def __call__(
         timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
 
         # add noise to latents using the timesteps
-        noise = np.random.randn(*init_latents.shape).astype(np.float32)
+        noise = generator.randn(*init_latents.shape).astype(latents_dtype)
         init_latents = self.scheduler.add_noise(
             torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
         )
@@ -341,14 +349,28 @@ def __call__(
                 callback(i, t, latents)
 
         latents = 1 / 0.18215 * latents
-        image = self.vae_decoder(latent_sample=latents)[0]
+        # image = self.vae_decoder(latent_sample=latents)[0]
+        # it seems like there is a strange result when using the half-precision vae decoder if batch size > 1
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )
 
         image = np.clip(image / 2 + 0.5, 0, 1)
         image = image.transpose((0, 2, 3, 1))
 
         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-            image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # the safety checker throws an error when run with batch size > 1
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
         else:
             has_nsfw_concept = None
 
```
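The `np.repeat` fix above addresses an over-duplication bug: `uncond_embeddings` already carries one row per prompt, so repeating by `batch_size * num_images_per_prompt` produced a batch that no longer matched `text_embeddings`. A small numpy sketch with illustrative shapes (77 and 768 are CLIP's sequence length and hidden size):

```python
import numpy as np

# Illustrative shapes only.
batch_size, num_images_per_prompt = 2, 3
uncond_embeddings = np.zeros((batch_size, 77, 768))  # already one row per prompt

old = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
new = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

print(old.shape)  # (12, 77, 768) -- too many rows, mismatched with text_embeddings
print(new.shape)  # (6, 77, 768)  -- batch_size * num_images_per_prompt, as intended
```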
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 32 additions & 8 deletions
```diff
@@ -23,11 +23,11 @@
 
 
 def prepare_mask_and_masked_image(image, mask, latents_shape):
-    image = np.array(image.convert("RGB"))
+    image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
     image = image[None].transpose(0, 3, 1, 2)
     image = image.astype(np.float32) / 127.5 - 1.0
 
-    image_mask = np.array(mask.convert("L"))
+    image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
     masked_image = image * (image_mask < 127.5)
 
     mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
@@ -138,6 +138,7 @@ def __call__(
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
+        generator: Optional[np.random.RandomState] = None,
         latents: Optional[np.ndarray] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -180,6 +181,8 @@ def __call__(
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`np.random.RandomState`, *optional*):
+                A np.random.RandomState to make generation deterministic.
             latents (`np.ndarray`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -222,6 +225,9 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
+        if generator is None:
+            generator = np.random
+
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
@@ -261,7 +267,7 @@ def __call__(
                 f" {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
-            uncond_tokens = [negative_prompt]
+            uncond_tokens = [negative_prompt] * batch_size
         elif batch_size != len(negative_prompt):
             raise ValueError(
                 f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -283,7 +289,7 @@ def __call__(
             uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]
 
             # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
+            uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -294,7 +300,7 @@ def __call__(
         latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = np.random.randn(*latents_shape).astype(latents_dtype)
+            latents = generator.randn(*latents_shape).astype(latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -307,6 +313,10 @@ def __call__(
         masked_image_latents = self.vae_encoder(sample=masked_image)[0]
         masked_image_latents = 0.18215 * masked_image_latents
 
+        # duplicate mask and masked_image_latents for each generation per prompt
+        mask = mask.repeat(batch_size * num_images_per_prompt, 0)
+        masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)
+
         mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
             np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
@@ -367,14 +377,28 @@ def __call__(
                 callback(i, t, latents)
 
         latents = 1 / 0.18215 * latents
-        image = self.vae_decoder(latent_sample=latents)[0]
+        # image = self.vae_decoder(latent_sample=latents)[0]
+        # it seems like there is a strange result when using the half-precision vae decoder if batch size > 1
+        image = np.concatenate(
+            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+        )
 
         image = np.clip(image / 2 + 0.5, 0, 1)
         image = image.transpose((0, 2, 3, 1))
 
         if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-            image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+            safety_checker_input = self.feature_extractor(
+                self.numpy_to_pil(image), return_tensors="np"
+            ).pixel_values.astype(image.dtype)
+            # the safety checker throws an error when run with batch size > 1
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
+                image_i, has_nsfw_concept_i = self.safety_checker(
+                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+                )
+                images.append(image_i)
+                has_nsfw_concept.append(has_nsfw_concept_i[0])
+            image = np.concatenate(images)
         else:
             has_nsfw_concept = None
 
```
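Beyond the shared fixes, the inpainting pipeline now resizes the input image and mask to the requested resolution in `prepare_mask_and_masked_image` (latent shape times the 8x VAE scale factor) and duplicates `mask` and `masked_image_latents` so their batch dimension matches the latents. A shape sketch of that duplication, with illustrative sizes (the single-channel mask and 4-channel latents follow Stable Diffusion's layout):

```python
import numpy as np

batch_size, num_images_per_prompt = 1, 4
mask = np.ones((1, 1, 64, 64), dtype=np.float32)                  # one mask at latent resolution
masked_image_latents = np.ones((1, 4, 64, 64), dtype=np.float32)  # one set of masked-image latents

# np.ndarray.repeat along axis 0 duplicates the sample once per generation per prompt
mask = mask.repeat(batch_size * num_images_per_prompt, 0)
masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)

print(mask.shape)                  # (4, 1, 64, 64) -- matches latents_shape[0]
print(masked_image_latents.shape)  # (4, 4, 64, 64)
```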