diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 8b067f93e733..fdfe5ec200e3 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -6,38 +6,13 @@ import torch import PIL -from diffusers.configuration_utils import FrozenDict +from diffusers import SchedulerMixin, StableDiffusionPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from diffusers.utils import deprecate, is_accelerate_available, logging - -# TODO: remove and import from diffusers.utils when the new version of diffusers is released -from packaging import version +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): - PIL_INTERPOLATION = { - "linear": PIL.Image.Resampling.BILINEAR, - "bilinear": PIL.Image.Resampling.BILINEAR, - "bicubic": PIL.Image.Resampling.BICUBIC, - "lanczos": PIL.Image.Resampling.LANCZOS, - "nearest": PIL.Image.Resampling.NEAREST, - } -else: - PIL_INTERPOLATION = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - "nearest": PIL.Image.NEAREST, - } -# ------------------------------------------------------------------------------ - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name re_attention = re.compile( @@ -146,7 +121,7 @@ def multiply_range(start_position, multiplier): return res -def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int): +def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): r""" Tokenize a list of prompts and return its tokens with weights of each token. @@ -207,7 +182,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd def get_unweighted_text_embeddings( - pipe: DiffusionPipeline, + pipe: StableDiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True, @@ -247,10 +222,10 @@ def get_unweighted_text_embeddings( def get_weighted_text_embeddings( - pipe: DiffusionPipeline, + pipe: StableDiffusionPipeline, prompt: Union[str, List[str]], uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 1, + max_embeddings_multiples: Optional[int] = 3, no_boseos_middle: Optional[bool] = False, skip_parsing: Optional[bool] = False, skip_weighting: Optional[bool] = False, @@ -264,14 +239,14 @@ def get_weighted_text_embeddings( Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. Args: - pipe (`DiffusionPipeline`): + pipe (`StableDiffusionPipeline`): Pipe to provide access to the tokenizer and the text encoder. prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. uncond_prompt (`str` or `List[str]`): The unconditional prompt or prompts for guide the image generation. 
If unconditional prompt is provided, the embeddings of prompt and uncond_prompt are concatenated. - max_embeddings_multiples (`int`, *optional*, defaults to `1`): + max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. no_boseos_middle (`bool`, *optional*, defaults to `False`): If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and @@ -387,11 +362,11 @@ def preprocess_image(image): return 2.0 * image - 1.0 -def preprocess_mask(mask): +def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"]) + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? @@ -400,7 +375,7 @@ def preprocess_mask(mask): return mask -class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): +class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing weighting in prompt. @@ -435,50 +410,12 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: SchedulerMixin, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - self.register_modules( + super().__init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, @@ -486,51 +423,171 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, + ): r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. + Encodes the prompt into text encoder hidden states. Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) + batch_size = len(prompt) if isinstance(prompt, list) else 1 - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. 
- """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) - def enable_sequential_cpu_offload(self): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + ) + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + bs_embed, seq_len, _ = uncond_embeddings.shape + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def check_inputs(self, prompt, height, width, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + + def get_timesteps(self, num_inference_steps, strength, device, is_text2img): + if is_text2img: + return self.scheduler.timesteps.to(device), num_inference_steps + else: + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(device) + return timesteps, num_inference_steps - t_start + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) else: - raise ImportError("Please install accelerate via `pip install accelerate`") + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] - device = self.device + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): + if image is None: + shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, None, None + else: + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + init_latents = torch.cat([init_latents] * batch_size, dim=0) + init_latents_orig = init_latents + shape = init_latents.shape - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) + # add noise to latents using the timesteps + if device.type == "mps": + 
noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep) + return latents, init_latents_orig, noise @torch.no_grad() def __call__( @@ -634,221 +691,111 @@ def __call__( init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs) image = init_image or image - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # get prompt text embeddings + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, strength, callback_steps) + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - **kwargs, + # 3. 
Encode input prompt + text_embeddings = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - noise = None - - if image is None: - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = ( - batch_size * num_images_per_prompt, - self.unet.in_channels, - height // 8, - width // 8, - ) - - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn( - latents_shape, - generator=generator, - device="cpu", - dtype=latents_dtype, - ).to(self.device) - else: - latents = torch.randn( - latents_shape, - generator=generator, - device=self.device, - dtype=latents_dtype, - ) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - timesteps = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + dtype = text_embeddings.dtype + + # 4. 
Preprocess image and mask + if isinstance(image, PIL.Image.Image): + image = preprocess_image(image) + if image is not None: + image = image.to(device=self.device, dtype=dtype) + if isinstance(mask_image, PIL.Image.Image): + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=dtype) + mask = torch.cat([mask] * batch_size * num_images_per_prompt) else: - if isinstance(image, PIL.Image.Image): - image = preprocess_image(image) - # encode the init image into latents and scale the latents - image = image.to(device=self.device, dtype=latents_dtype) - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) - init_latents_orig = init_latents - - # preprocess mask - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = preprocess_mask(mask_image) - mask_image = mask_image.to(device=self.device, dtype=latents_dtype) - mask = torch.cat([mask_image] * batch_size * num_images_per_prompt) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - if self.device.type == "mps": - # randn does not exist on mps - noise = torch.randn( - init_latents.shape, - generator=generator, - device="cpu", - dtype=latents_dtype, - ).to(self.device) - else: - noise = torch.randn( - init_latents.shape, - generator=generator, - device=self.device, - dtype=latents_dtype, - ) - latents = self.scheduler.add_noise(init_latents, noise, timesteps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.device) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) - image, has_nsfw_concept = self.safety_checker( - images=image, - clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype), - ) - else: - has_nsfw_concept = None + mask = None + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, init_latents_orig, noise = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + height, + width, + dtype, + device, + generator, + latents, + ) + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept) + return image, has_nsfw_concept return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) @@ -868,6 +815,7 @@ def text2img( output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -915,6 +863,9 @@ def text2img( callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. @@ -940,6 +891,7 @@ def text2img( output_type=output_type, return_dict=return_dict, callback=callback, + is_cancelled_callback=is_cancelled_callback, callback_steps=callback_steps, **kwargs, ) @@ -959,6 +911,7 @@ def img2img( output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -1007,6 +960,9 @@ def img2img( callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 
+ is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. @@ -1031,6 +987,7 @@ def img2img( output_type=output_type, return_dict=return_dict, callback=callback, + is_cancelled_callback=is_cancelled_callback, callback_steps=callback_steps, **kwargs, ) @@ -1051,6 +1008,7 @@ def inpaint( output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, **kwargs, ): @@ -1103,6 +1061,9 @@ def inpaint( callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. @@ -1128,6 +1089,7 @@ def inpaint( output_type=output_type, return_dict=return_dict, callback=callback, + is_cancelled_callback=is_cancelled_callback, callback_steps=callback_steps, **kwargs, ) diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py index a6e765abe66b..4abdf782072d 100644 --- a/examples/community/lpw_stable_diffusion_onnx.py +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -6,35 +6,13 @@ import torch import PIL -from diffusers.onnx_utils import OnnxRuntimeModel -from diffusers.pipeline_utils import DiffusionPipeline +from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin +from diffusers.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from diffusers.utils import deprecate, logging - -# TODO: remove and import from diffusers.utils when the new version of diffusers is released -from packaging import version +from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from transformers import CLIPFeatureExtractor, CLIPTokenizer -if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): - PIL_INTERPOLATION = { - "linear": PIL.Image.Resampling.BILINEAR, - "bilinear": PIL.Image.Resampling.BILINEAR, - "bicubic": PIL.Image.Resampling.BICUBIC, - "lanczos": PIL.Image.Resampling.LANCZOS, - "nearest": PIL.Image.Resampling.NEAREST, - } -else: - PIL_INTERPOLATION = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - "nearest": PIL.Image.NEAREST, - } -# ------------------------------------------------------------------------------ - logger = logging.get_logger(__name__) # pylint: disable=invalid-name re_attention = re.compile( @@ -262,7 +240,7 @@ def get_weighted_text_embeddings( Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the 
original mean. Args: - pipe (`DiffusionPipeline`): + pipe (`OnnxStableDiffusionPipeline`): Pipe to provide access to the tokenizer and the text encoder. prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. @@ -392,11 +370,11 @@ def preprocess_image(image): return 2.0 * image - 1.0 -def preprocess_mask(mask): +def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"]) + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? @@ -404,7 +382,7 @@ def preprocess_mask(mask): return mask -class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): +class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing weighting in prompt. @@ -420,12 +398,12 @@ def __init__( text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, unet: OnnxRuntimeModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: SchedulerMixin, safety_checker: OnnxRuntimeModel, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): - super().__init__() - self.register_modules( + super().__init__( vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, @@ -434,8 +412,171 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker, + ) + self.unet_in_channels = 4 + self.vae_scale_factor = 8 + + def _encode_prompt( + self, + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, ) + text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0) + if do_classifier_free_guidance: + uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0) + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def check_inputs(self, prompt, height, width, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def get_timesteps(self, num_inference_steps, strength, is_text2img): + if is_text2img: + return self.scheduler.timesteps, num_inference_steps + else: + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:] + return timesteps, num_inference_steps - t_start + + def run_safety_checker(self, image): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # There will throw an error if use safety_checker directly and batchsize>1 + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None): + if image is None: + shape = ( + batch_size, + self.unet_in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + # scale the initial noise by the standard deviation required by the scheduler + latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy() + return latents, None, None + else: + init_latents = self.vae_encoder(sample=image)[0] + init_latents = 0.18215 * init_latents + init_latents = np.concatenate([init_latents] * batch_size, axis=0) + init_latents_orig = init_latents + shape = init_latents.shape + + # add noise to latents using the timesteps + noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype) + latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), timestep + ).numpy() + return latents, init_latents_orig, noise + @torch.no_grad() def __call__( self, @@ -450,7 +591,7 @@ def __call__( strength: float = 0.8, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[torch.Generator] = None, latents: Optional[np.ndarray] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", @@ -501,8 +642,9 @@ def __call__( eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`np.random.RandomState`, *optional*): - A np.random.RandomState to make generation deterministic. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. latents (`np.ndarray`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -537,204 +679,123 @@ def __call__( init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs) image = init_image or image - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + # 0. 
Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # get prompt text embeddings + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, strength, callback_steps) + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - if generator is None: - generator = np.random - - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - **kwargs, + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, ) - - text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0) - if do_classifier_free_guidance: - uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0) - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - noise = None - - if image is None: - latents_shape = ( - batch_size * num_images_per_prompt, - 4, - height // 8, - width // 8, - ) - - if latents is None: - latents = generator.randn(*latents_shape).astype(latents_dtype) - elif latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - timesteps = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + dtype = text_embeddings.dtype + + # 4. 
Preprocess image and mask + if isinstance(image, PIL.Image.Image): + image = preprocess_image(image) + if image is not None: + image = image.astype(dtype) + if isinstance(mask_image, PIL.Image.Image): + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) + if mask_image is not None: + mask = mask_image.astype(dtype) + mask = np.concatenate([mask] * batch_size * num_images_per_prompt) else: - if isinstance(image, PIL.Image.Image): - image = preprocess_image(image) - # encode the init image into latents and scale the latents - image = image.astype(latents_dtype) - init_latents = self.vae_encoder(sample=image)[0] - init_latents = 0.18215 * init_latents - init_latents = np.concatenate([init_latents] * batch_size * num_images_per_prompt) - init_latents_orig = init_latents - - # preprocess mask - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = preprocess_mask(mask_image) - mask_image = mask_image.astype(latents_dtype) - mask = np.concatenate([mask_image] * batch_size * num_images_per_prompt) - - # check sizes - if not mask.shape == init_latents.shape: - print(mask.shape, init_latents.shape) - raise ValueError("The mask and image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) + mask = None - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt) - - # add noise to latents using the timesteps - noise = generator.randn(*init_latents.shape).astype(latents_dtype) - latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), timesteps - ).numpy() - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:] - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - sample=latent_model_input, - timestep=np.array([t]), - encoder_hidden_states=text_embeddings, - ) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.numpy() - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise( - torch.from_numpy(init_latents_orig), - torch.from_numpy(noise), - torch.tensor([t]), - ).numpy() - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps) + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, init_latents_orig, noise = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + height, + width, + dtype, + generator, + latents, + ) - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a problem for using half-precision vae decoder if batchsize>1 - image = [] - for i in range(latents.shape[0]): - image.append(self.vae_decoder(latent_sample=latents[i : i + 1])[0]) - image = np.concatenate(image) + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.numpy() + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, + timestep=np.array([t], dtype=timestep_dtype), + encoder_hidden_states=text_embeddings, + ) + noise_pred = noise_pred[0] - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor( - self.numpy_to_pil(image), return_tensors="np" - ).pixel_values.astype(image.dtype) - # There will throw an error if use safety_checker directly and batchsize>1 - images, has_nsfw_concept = [], [] - for i in range(image.shape[0]): - image_i, has_nsfw_concept_i = self.safety_checker( - clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs ) - images.append(image_i) - has_nsfw_concept.append(has_nsfw_concept_i[0]) - image = np.concatenate(images) - else: - has_nsfw_concept = None - + latents = scheduler_output.prev_sample.numpy() + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise( + torch.from_numpy(init_latents_orig), + torch.from_numpy(noise), + t, + ).numpy() + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image) + + # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept) + return image, has_nsfw_concept return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) @@ -748,7 +809,7 @@ def text2img( guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[torch.Generator] = None, latents: Optional[np.ndarray] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", @@ -783,8 +844,9 @@ def text2img( eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`np.random.RandomState`, *optional*): - A np.random.RandomState to make generation deterministic. 
+ generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. latents (`np.ndarray`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -839,7 +901,7 @@ def img2img( guidance_scale: Optional[float] = 7.5, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[torch.Generator] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -878,8 +940,9 @@ def img2img( eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`np.random.RandomState`, *optional*): - A np.random.RandomState to make generation deterministic. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`): @@ -930,7 +993,7 @@ def inpaint( guidance_scale: Optional[float] = 7.5, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[torch.Generator] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -973,8 +1036,9 @@ def inpaint( eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`np.random.RandomState`, *optional*): - A np.random.RandomState to make generation deterministic. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`):
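
Usage sketch (not part of the patch): a minimal example of how the refactored community pipeline might be exercised after this change. It assumes the standard `custom_pipeline="lpw_stable_diffusion"` loading path, a CUDA device, and the `runwayml/stable-diffusion-v1-5` checkpoint — these are illustrative assumptions, not anything this diff introduces. The `text2img`, `max_embeddings_multiples`, and `is_cancelled_callback` arguments are the ones added or surfaced by the diff above.

# Hypothetical usage sketch; checkpoint name and device are assumptions.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",        # any SD 1.x checkpoint should work here
    custom_pipeline="lpw_stable_diffusion",  # loads the community pipeline modified in this diff
    torch_dtype=torch.float16,
).to("cuda")

# Over-length, weighted prompt: () boosts, [] attenuates, (token:1.3) sets an explicit weight,
# as handled by the pipeline's re_attention parser.
prompt = "a (masterpiece:1.3), best quality photo of a castle on a cliff at sunset, " * 8

def is_cancelled():
    # Returning True aborts generation early; __call__ then returns None.
    return False

result = pipe.text2img(
    prompt,
    negative_prompt="lowres, [bad anatomy]",
    max_embeddings_multiples=3,         # allow prompts up to 3x the text encoder window (new default)
    num_inference_steps=30,
    is_cancelled_callback=is_cancelled,
)
if result is not None:
    result.images[0].save("castle.png")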