33 changes: 22 additions & 11 deletions scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -69,8 +69,15 @@ def onnx_export(


@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int):
pipeline = StableDiffusionPipeline.from_pretrained(model_path)
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
dtype = torch.float16 if fp16 else torch.float32
if fp16 and torch.cuda.is_available():
device = "cuda"
elif fp16 and not torch.cuda.is_available():
raise ValueError("`float16` model export is only supported on GPUs with CUDA")
else:
device = "cpu"
pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
output_path = Path(output_path)

# TEXT ENCODER
@@ -84,7 +91,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
onnx_export(
pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(text_input.input_ids.to(torch.int32)),
model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)),
output_path=output_path / "text_encoder" / "model.onnx",
ordered_input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"],
@@ -100,9 +107,9 @@ def convert_models(model_path: str, output_path: str, opset: int):
onnx_export(
pipeline.unet,
model_args=(
torch.randn(2, pipeline.unet.in_channels, 64, 64),
torch.LongTensor([0, 1]),
torch.randn(2, 77, 768),
torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
torch.LongTensor([0, 1]).to(device=device),
torch.randn(2, 77, 768).to(device=device, dtype=dtype),
False,
),
output_path=unet_path,
@@ -139,7 +146,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
onnx_export(
vae_encoder,
model_args=(torch.randn(1, 3, 512, 512), False),
model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
output_path=output_path / "vae_encoder" / "model.onnx",
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
@@ -155,7 +162,7 @@ def convert_models(model_path: str, output_path: str, opset: int):
vae_decoder.forward = vae_encoder.decode
onnx_export(
vae_decoder,
model_args=(torch.randn(1, 4, 64, 64), False),
model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
output_path=output_path / "vae_decoder" / "model.onnx",
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
@@ -171,13 +178,16 @@ def convert_models(model_path: str, output_path: str, opset: int):
safety_checker.forward = safety_checker.forward_onnx
onnx_export(
pipeline.safety_checker,
model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)),
model_args=(
torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
),
output_path=output_path / "safety_checker" / "model.onnx",
ordered_input_names=["clip_input", "images"],
output_names=["out_images", "has_nsfw_concepts"],
dynamic_axes={
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"images": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
},
opset=opset,
)
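The corrected `images` entry matches the NHWC layout of the dummy tensor passed to the export above. As a standalone sanity check (a sketch, not part of this diff; the shapes are copied from the export call):

import torch

# Shapes taken from the safety-checker export above:
clip_input = torch.randn(1, 3, 224, 224)  # NCHW: (batch, channels, height, width)
images = torch.randn(1, 512, 512, 3)      # NHWC: (batch, height, width, channels)

# The old dynamic_axes labeled `images` as NCHW; its channel axis is really last.
assert clip_input.shape[1] == 3
assert images.shape[3] == 3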
@@ -221,7 +231,8 @@ def convert_models(model_path: str, output_path: str, opset: int):
type=int,
help="The version of the ONNX operator set to use.",
)
parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

args = parser.parse_args()

convert_models(args.model_path, args.output_path, args.opset)
convert_models(args.model_path, args.output_path, args.opset, args.fp16)
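As a usage sketch (not part of the diff), the new flag could be exercised as below; the checkpoint name, output directory, and opset value are illustrative assumptions:

import subprocess

# Hypothetical invocation of the updated script. `--fp16` requires CUDA:
# convert_models() raises a ValueError when fp16 is requested without a GPU.
subprocess.run(
    [
        "python", "scripts/convert_stable_diffusion_checkpoint_to_onnx.py",
        "--model_path", "CompVis/stable-diffusion-v1-4",  # assumed checkpoint
        "--output_path", "./sd_onnx_fp16",                # assumed output dir
        "--opset", "14",                                  # assumed opset version
        "--fp16",
    ],
    check=True,
)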
@@ -55,7 +55,9 @@ def __call__(
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -81,6 +83,9 @@ def __call__(
f" {type(callback_steps)}."
)

if generator is None:
generator = np.random
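Defaulting to the `np.random` module preserves the old behaviour, while a seeded `np.random.RandomState` pins the sampled noise. A minimal standalone sketch of the determinism this enables (not part of the diff):

import numpy as np

# Two identically seeded RandomState objects yield identical latent noise,
# which is what makes `generator=np.random.RandomState(seed)` reproducible.
gen_a = np.random.RandomState(0)
gen_b = np.random.RandomState(0)
assert np.array_equal(gen_a.randn(1, 4, 64, 64), gen_b.randn(1, 4, 64, 64))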

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
@@ -98,6 +103,7 @@ def __call__(
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
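`np.repeat` on axis 0 duplicates each prompt's embedding block-wise, keeping all copies of a prompt adjacent; a tiny sketch with illustrative values (not part of the diff):

import numpy as np

# Two prompts (embedding dim 1), two images per prompt:
emb = np.array([[1.0], [2.0]])
print(np.repeat(emb, 2, axis=0))  # [[1.] [1.] [2.] [2.]]: copies stay grouped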

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
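As a sketch (not part of this diff), the combination step the comment above refers to is:

# Classifier-free guidance: blend the unconditional and text-conditioned
# noise predictions; guidance_scale == 1 recovers the plain text prediction.
def guide(noise_pred_uncond, noise_pred_text, guidance_scale):
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)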
@@ -133,16 +139,18 @@ def __call__(
return_tensors="np",
)
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

# get the initial random noise unless the user supplied it
latents_shape = (batch_size, 4, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
if latents is None:
latents = np.random.randn(*latents_shape).astype(np.float32)
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

@@ -185,13 +193,30 @@ def __call__(
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems like the half-precision VAE decoder gives strange results when batch size > 1, so decode one image at a time
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
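The slice `latents[i : i + 1]` keeps the batch axis, so each decoder call still receives a rank-4 tensor; a standalone check (sketch, not part of the diff):

import numpy as np

latents = np.zeros((4, 4, 64, 64), dtype=np.float16)
assert latents[0:1].shape == (1, 4, 64, 64)  # slicing preserves the batch axis
assert latents[0].shape == (4, 64, 64)       # plain indexing would drop it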

image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))

safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# the safety_checker raises an error when called with a batch size > 1, so run it on one image at a time
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None

if output_type == "pil":
image = self.numpy_to_pil(image)
@@ -121,6 +121,7 @@ def __call__(
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[np.random.RandomState] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -159,6 +160,8 @@ def __call__(
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
A `np.random.RandomState` to make generation deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -197,6 +200,9 @@ def __call__(
f" {type(callback_steps)}."
)

if generator is None:
generator = np.random

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

@@ -239,7 +245,7 @@ def __call__(
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
uncond_tokens = [negative_prompt] * batch_size
elif batch_size != len(negative_prompt):
raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
else:
@@ -257,13 +263,15 @@ def __call__(
uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]

# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

latents_dtype = text_embeddings.dtype
init_image = init_image.astype(latents_dtype)
# encode the init image into latents and scale the latents
init_latents = self.vae_encoder(sample=init_image)[0]
init_latents = 0.18215 * init_latents
@@ -297,7 +305,7 @@ def __call__(
timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)

# add noise to latents using the timesteps
noise = np.random.randn(*init_latents.shape).astype(np.float32)
noise = generator.randn(*init_latents.shape).astype(latents_dtype)
init_latents = self.scheduler.add_noise(
torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
)
@@ -341,14 +349,28 @@ def __call__(
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems like the half-precision VAE decoder gives strange results when batch size > 1, so decode one image at a time
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)

image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))

if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# the safety_checker raises an error when called with a batch size > 1, so run it on one image at a time
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None

@@ -23,11 +23,11 @@


def prepare_mask_and_masked_image(image, mask, latents_shape):
image = np.array(image.convert("RGB"))
image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
image = image[None].transpose(0, 3, 1, 2)
image = image.astype(np.float32) / 127.5 - 1.0

image_mask = np.array(mask.convert("L"))
image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
masked_image = image * (image_mask < 127.5)

mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
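The added `.resize(...)` calls force both inputs to 8x the latent resolution, so the encoded latents match `latents_shape` regardless of the original image size. A standalone sketch with assumed input sizes (not part of the diff):

import PIL.Image

# Hypothetical mismatched inputs; latents_shape (64, 64) implies 512x512 pixels.
image = PIL.Image.new("RGB", (500, 480))
latents_shape = (64, 64)
resized = image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8))
assert resized.size == (512, 512)  # PIL reports (width, height)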
@@ -138,6 +138,7 @@ def __call__(
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -180,6 +181,8 @@ def __call__(
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
A `np.random.RandomState` to make generation deterministic.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -222,6 +225,9 @@ def __call__(
f" {type(callback_steps)}."
)

if generator is None:
generator = np.random

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

@@ -261,7 +267,7 @@ def __call__(
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
uncond_tokens = [negative_prompt] * batch_size
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -283,7 +289,7 @@ def __call__(
uncond_embeddings = self.text_encoder(input_ids=uncond_input_ids.astype(np.int32))[0]

# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = np.repeat(uncond_embeddings, batch_size * num_images_per_prompt, axis=0)
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -294,7 +300,7 @@ def __call__(
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
latents = np.random.randn(*latents_shape).astype(latents_dtype)
latents = generator.randn(*latents_shape).astype(latents_dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -307,6 +313,10 @@ def __call__(
masked_image_latents = self.vae_encoder(sample=masked_image)[0]
masked_image_latents = 0.18215 * masked_image_latents

# duplicate mask and masked_image_latents for each generation per prompt
mask = mask.repeat(batch_size * num_images_per_prompt, 0)
masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)

mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
@@ -367,14 +377,28 @@ def __call__(
callback(i, t, latents)

latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems like the half-precision VAE decoder gives strange results when batch size > 1, so decode one image at a time
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)

image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))

if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# the safety_checker raises an error when called with a batch size > 1, so run it on one image at a time
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None

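Taken together, a hypothetical end-to-end call against the updated text-to-image ONNX pipeline (the class name, provider string, and prompt below are assumptions, not taken from this diff):

import numpy as np
from diffusers import OnnxStableDiffusionPipeline  # assumed public import path

# Load the fp16 export produced above; the execution provider is an assumption.
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./sd_onnx_fp16", provider="CUDAExecutionProvider"
)
result = pipe(
    "an astronaut riding a horse",       # assumed example prompt
    num_images_per_prompt=2,             # new in this PR: several images per prompt
    generator=np.random.RandomState(0),  # new in this PR: reproducible noise
)
images = result.images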