From 073dbd69b85e885e4ae236ef4f670ddc2cec8ba3 Mon Sep 17 00:00:00 2001 From: Asfiya Baig Date: Tue, 4 Apr 2023 11:40:47 -0700 Subject: [PATCH 1/5] Add SD/txt2img Community Pipeline to diffusers along with TensorRT utils Signed-off-by: Asfiya Baig --- .../stable_diffusion_tensorrt_txt2img.py | 532 ++++++++++++++++++ examples/community/tensorrt_utils.py | 355 ++++++++++++ 2 files changed, 887 insertions(+) create mode 100644 examples/community/stable_diffusion_tensorrt_txt2img.py create mode 100644 examples/community/tensorrt_utils.py diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py new file mode 100644 index 000000000000..d26959abbe13 --- /dev/null +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -0,0 +1,532 @@ +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +from typing import List, Optional, Union + +import torch +import tensorrt as trt +from polygraphy import cuda +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from tensorrt_utils import TRT_LOGGER, build_engines, device_view, runEngine, Optimizer, BaseModel +from huggingface_hub import snapshot_download +from diffusers.utils import DIFFUSERS_CACHE, logging +from diffusers.schedulers import DDIMScheduler +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput, StableDiffusionSafetyChecker + +''' +Installation instructions +python3 -m pip install --upgrade tensorrt polygraphy onnx-graphsurgeon -i https://pypi.ngc.nvidia.com +python3 -m pip install onnxruntime +''' + +EXAMPLE_DOC_STRING = """ + Examples: + ```python3 + >>> import torch + >>> from diffusers import DDIMScheduler + + >>> # Use the DDIMScheduler scheduler here instead + >>> scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", + subfolder="scheduler") + + >>> pipe = TensorRTStableDiffusionPipeline("stabilityai/stable-diffusion-2-1", + revision='fp16', + torch_dtype=torch.float16, + scheduler=scheduler + ) + + >>> pipe = pipe.to("cuda") + + >>> prompt = "a beautiful photograph of Mt. Fuji during cherry blossom" + >>> image = pipe(prompt).images[0] + >>> image.save('tensorrt_mt_fuji.png') + ``` +""" + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class TensorRTStableDiffusionPipeline(StableDiffusionPipeline): + r""" + Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion. + + This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
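+
+    In addition to the standard diffusers arguments, the constructor accepts TensorRT-specific
+    options (`stages`, `image_height`, `image_width`, `max_batch_size`, `onnx_opset`, `onnx_dir`,
+    `engine_dir`, `force_engine_rebuild` and `timing_cache`) that control the ONNX export and
+    TensorRT engine build.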
+ + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + stages=['clip','unet','vae'], + image_height: int = 768, + image_width: int = 768, + max_batch_size: int = 16, + # ONNX export parameters + onnx_opset: int = 17, + onnx_dir: str = 'onnx', + # TensorRT engine build parameters + engine_dir: str = 'engine', + force_engine_rebuild: bool = False, + timing_cache: str = 'timing_cache' + ): + super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker) + + self.vae.forward = self.vae.decode + + self.stages = stages + self.image_height, self.image_width = image_height, image_width + self.inpaint = False + self.onnx_opset = onnx_opset + self.onnx_dir = onnx_dir + self.engine_dir = engine_dir + self.force_engine_rebuild = force_engine_rebuild + self.timing_cache = timing_cache + self.build_static_batch = False + self.build_dynamic_shape = False + self.build_preview_features = False + + self.max_batch_size = max_batch_size + # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation. 
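+        # "WAR" = workaround: TensorRT engines built with dynamic shapes, or for resolutions
+        # above 512x512, are restricted to a maximum batch size of 4 below.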
+        if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
+            self.max_batch_size = 4

+        self.stream = None # loaded in loadResources()
+        self.models = {} # loaded in __loadModels()
+        self.engine = {} # loaded in build_engines()
+
+    def __loadModels(self):
+        # Load pipeline models
+        self.embedding_dim = self.text_encoder.config.hidden_size
+        models_args = {'device': self.torch_device, 'max_batch_size': self.max_batch_size, \
+            'embedding_dim': self.embedding_dim, 'inpaint': self.inpaint}
+        if 'clip' in self.stages:
+            self.models['clip'] = make_CLIP(self.text_encoder, **models_args)
+        if 'unet' in self.stages:
+            self.models['unet'] = make_UNet(self.unet, **models_args)
+        if 'vae' in self.stages:
+            self.models['vae'] = make_VAE(self.vae, **models_args)
+
+    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
+        super().to(torch_device, silence_dtype_warnings)
+
+        self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
+        self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
+        self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)
+
+        # set device
+        self.torch_device = self._execution_device
+        logger.warning(f"Running inference on device: {self.torch_device}")
+
+        # load models
+        self.__loadModels()
+
+        # build engines
+        self.engine = build_engines(self.models, self.engine_dir, self.onnx_dir, self.onnx_opset, \
+            opt_image_height=self.image_height, opt_image_width=self.image_width, \
+            force_engine_rebuild=self.force_engine_rebuild, static_batch=self.build_static_batch, \
+            static_shape=not self.build_dynamic_shape, enable_preview=self.build_preview_features, \
+            timing_cache=self.timing_cache)
+
+        return self
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+
+        cls.cached_folder = (
+            pretrained_model_name_or_path
+            if os.path.isdir(pretrained_model_name_or_path)
+            else snapshot_download(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+        )
+        return pipe
+
+    def __encode_prompt(
+        self,
+        prompt,
+        negative_prompt
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt to be encoded.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
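+
+        Returns:
+            `torch.Tensor`: the unconditional and prompt embeddings concatenated along the batch
+            dimension (unconditional embeddings first), cast to `torch.float16`.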
+ """ + # Tokenize prompt + text_input_ids = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids.type(torch.int32).to(self.torch_device) + + text_input_ids_inp = device_view(text_input_ids) + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = runEngine(self.engine['clip'], {"input_ids": text_input_ids_inp}, self.stream)['text_embeddings'].clone() + + # Tokenize negative prompt + uncond_input_ids = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids.type(torch.int32).to(self.torch_device) + uncond_input_ids_inp = device_view(uncond_input_ids) + uncond_embeddings = runEngine(self.engine['clip'], {"input_ids": uncond_input_ids_inp}, self.stream)['text_embeddings'] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + + def __denoise_latent(self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + for step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict the noise residual + embeddings_dtype = np.float16 + timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep + + sample_inp = device_view(latent_model_input) + timestep_inp = device_view(timestep_float) + embeddings_inp = device_view(text_embeddings) + noise_pred = runEngine(self.engine['unet'], \ + {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp}, \ + self.stream)['latent'] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1. / 0.18215 * latents + return latents + + def __decode_latent(self, latents): + images = runEngine(self.engine['vae'], {"latent": device_view(latents)}, self.stream)['images'] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def __loadResources(self, image_height, image_width, batch_size): + self.stream = cuda.Stream() + + # Allocate buffers for TensorRT engine bindings + for model_name, obj in self.models.items(): + self.engine[model_name].allocate_buffers(shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. 
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+
+        """
+        self.generator = generator
+        self.denoising_steps = num_inference_steps
+        self.guidance_scale = guidance_scale
+
+        # Pre-compute latent input scales and linear multistep coefficients
+        self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
+
+        # Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+            prompt = [prompt]
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")
+
+        if negative_prompt is None:
+            negative_prompt = [""] * batch_size
+
+        if negative_prompt is not None and isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt]
+
+        assert len(prompt) == len(negative_prompt)
+
+        if batch_size > self.max_batch_size:
+            raise ValueError(f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. 
If dynamic shape is used, then maximum batch size is 4") + + # load resources + self.__loadResources(self.image_height, self.image_width, batch_size) + + with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER): + # CLIP text encoder + text_embeddings = self.__encode_prompt(prompt, negative_prompt) + + # Pre-initialize latents + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents(batch_size, + num_channels_latents, + self.image_height, + self.image_width, + torch.float32, + self.torch_device, + generator) + + # UNet denoiser + latents = self.__denoise_latent(latents, text_embeddings) + + # VAE decode latent + images = self.__decode_latent(latents) + + images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) + images = self.numpy_to_pil(images) + return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +class CLIP(BaseModel): + def __init__(self, + model, + device, + max_batch_size, + embedding_dim + ): + super(CLIP, self).__init__(model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + self.name = "CLIP" + + def get_input_names(self): + return ['input_ids'] + + def get_output_names(self): + return ['text_embeddings', 'pooler_output'] + + def get_dynamic_axes(self): + return { + 'input_ids': {0: 'B'}, + 'text_embeddings': {0: 'B'} + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + 'input_ids': [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return { + 'input_ids': (batch_size, self.text_maxlen), + 'text_embeddings': (batch_size, self.text_maxlen, self.embedding_dim) + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=['text_embeddings']) # rename network output + opt_onnx_graph = opt.cleanup(return_onnx=True) + return opt_onnx_graph + +def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False): + return CLIP(model, device=device, max_batch_size=max_batch_size, \ + embedding_dim=embedding_dim) + +class UNet(BaseModel): + def __init__(self, + model, + fp16=False, + device='cuda', + max_batch_size=16, + embedding_dim=768, + text_maxlen=77, + unet_dim=4 + ): + super(UNet, self).__init__(model=model, fp16=fp16, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim, text_maxlen=text_maxlen) + self.unet_dim = unet_dim + self.name = "UNet" + + def get_input_names(self): + return ['sample', 'timestep', 'encoder_hidden_states'] + + def get_output_names(self): + return ['latent'] + + def get_dynamic_axes(self): + return { + 'sample': {0: '2B', 2: 'H', 3: 'W'}, + 'encoder_hidden_states': {0: '2B'}, + 'latent': {0: '2B', 2: 'H', 3: 'W'} + } + + def get_input_profile(self, batch_size, image_height, 
image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \ + self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + 'sample': [(2*min_batch, self.unet_dim, min_latent_height, min_latent_width), (2*batch_size, self.unet_dim, latent_height, latent_width), (2*max_batch, self.unet_dim, max_latent_height, max_latent_width)], + 'encoder_hidden_states': [(2*min_batch, self.text_maxlen, self.embedding_dim), (2*batch_size, self.text_maxlen, self.embedding_dim), (2*max_batch, self.text_maxlen, self.embedding_dim)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + 'sample': (2*batch_size, self.unet_dim, latent_height, latent_width), + 'encoder_hidden_states': (2*batch_size, self.text_maxlen, self.embedding_dim), + 'latent': (2*batch_size, 4, latent_height, latent_width) + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn(2*batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device), + torch.tensor([1.], dtype=torch.float32, device=self.device), + torch.randn(2*batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device) + ) + +def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False): + return UNet(model, fp16=True, device=device, max_batch_size=max_batch_size, \ + embedding_dim=embedding_dim, unet_dim=(9 if inpaint else 4)) + +class VAE(BaseModel): + def __init__(self, + model, + device, + max_batch_size, + embedding_dim + ): + super(VAE, self).__init__(model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + self.name = "VAE decoder" + + def get_input_names(self): + return ['latent'] + + def get_output_names(self): + return ['images'] + + def get_dynamic_axes(self): + return { + 'latent': {0: 'B', 2: 'H', 3: 'W'}, + 'images': {0: 'B', 2: '8H', 3: '8W'} + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \ + self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + 'latent': [(min_batch, 4, min_latent_height, min_latent_width), (batch_size, 4, latent_height, latent_width), (max_batch, 4, max_latent_height, max_latent_width)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + 'latent': (batch_size, 4, latent_height, latent_width), + 'images': (batch_size, 3, image_height, image_width) + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) + +def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False): + return VAE(model, 
device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) diff --git a/examples/community/tensorrt_utils.py b/examples/community/tensorrt_utils.py new file mode 100644 index 000000000000..78118a716407 --- /dev/null +++ b/examples/community/tensorrt_utils.py @@ -0,0 +1,355 @@ +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import gc +from collections import OrderedDict +import onnx +import torch +import numpy as np +from copy import copy +import tensorrt as trt +import onnx_graphsurgeon as gs +from onnx import shape_inference +from polygraphy.backend.onnx.loader import fold_constants +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.trt import CreateConfig, Profile +from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine +from polygraphy.backend.trt import util as trt_util +from polygraphy import cuda +from diffusers.utils import logging + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.uint8 : torch.uint8, + np.int8 : torch.int8, + np.int16 : torch.int16, + np.int32 : torch.int32, + np.int64 : torch.int64, + np.float16 : torch.float16, + np.float32 : torch.float32, + np.float64 : torch.float64, + np.complex64 : torch.complex64, + np.complex128 : torch.complex128 +} +if np.version.full_version >= "1.24.0": + numpy_to_torch_dtype_dict[np.bool_] = torch.bool +else: + numpy_to_torch_dtype_dict[np.bool] = torch.bool + +# Map of torch dtype -> numpy dtype +torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()} + +def device_view(t): + return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype]) + +class Engine(): + def __init__( + self, + engine_path, + ): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + + def __del__(self): + [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray) ] + del self.engine + del self.context + del self.buffers + del self.tensors + + def build(self, onnx_path, fp16, input_profile=None, enable_preview=False, enable_all_tactics=False, timing_cache=None, workspace_size=0): + logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + + config_kwargs['preview_features'] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805] + if enable_preview: + # Faster dynamic shapes made optional since it increases engine build time. 
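+            # Requested via the `enable_preview` flag, which `build_engines` forwards from the
+            # pipeline's `build_preview_features` setting.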
+ config_kwargs['preview_features'].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805) + if workspace_size > 0: + config_kwargs['memory_pool_limits'] = {trt.MemoryPoolType.WORKSPACE: workspace_size} + if not enable_all_tactics: + config_kwargs['tactic_sources'] = [] + + engine = engine_from_network( + network_from_onnx_path(onnx_path), + config=CreateConfig(fp16=fp16, + profiles=[p], + load_timing_cache=timing_cache, + **config_kwargs + ), + save_timing_cache=timing_cache + ) + save_engine(engine, path=self.engine_path) + + def load(self): + logger.warning(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self): + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device='cuda'): + for idx in range(trt_util.get_bindings_per_profile(self.engine)): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) + + def infer(self, feed_dict, stream): + start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) + # shallow copy of ordered dict + device_buffers = copy(self.buffers) + for name, buf in feed_dict.items(): + assert isinstance(buf, cuda.DeviceView) + device_buffers[name] = buf + bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] + noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) + if not noerror: + raise ValueError(f"ERROR: inference failed.") + + return self.tensors + +class Optimizer(): + def __init__( + self, + onnx_graph, + ): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError("ERROR: model size exceeds supported 2GB limit") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + +class BaseModel(): + def __init__( + self, + model, + fp16=False, + device='cuda', + max_batch_size=16, + embedding_dim=768, + text_maxlen=77, + ): + self.model = model + self.name = "SD Model" + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + 
self.text_maxlen = text_maxlen + + def get_model(self): + return self.model + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + onnx_opt_graph = opt.cleanup(return_onnx=True) + return onnx_opt_graph + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_shape else self.min_image_shape + max_image_height = image_height if static_shape else self.max_image_shape + min_image_width = image_width if static_shape else self.min_image_shape + max_image_width = image_width if static_shape else self.max_image_shape + min_latent_height = latent_height if static_shape else self.min_latent_shape + max_latent_height = latent_height if static_shape else self.max_latent_shape + min_latent_width = latent_width if static_shape else self.min_latent_shape + max_latent_width = latent_width if static_shape else self.max_latent_shape + return (min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width) + + +def getOnnxPath(model_name, onnx_dir, opt=True): + return os.path.join(onnx_dir, model_name+('.opt' if opt else '')+'.onnx') + +def getEnginePath(model_name, engine_dir): + return os.path.join(engine_dir, model_name+'.plan') + +def build_engines( + models: dict, + engine_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_shape=True, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + max_workspace_size=0 +): + built_engines = {} + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + # Export models to ONNX + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + if force_engine_rebuild or not os.path.exists(engine_path): + logger.warning("Building Engines...") + logger.warning("Engine build can take a while to complete") + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + if force_engine_rebuild or not os.path.exists(onnx_path): + logger.warning(f"Exporting model: {onnx_path}") + model = model_obj.get_model() + with torch.inference_mode(), 
torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export(model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.warning(f"Found cached model: {onnx_path}") + + # Optimize onnx + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + logger.warning(f"Generating optimizing model: {onnx_opt_path}") + onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path)) + onnx.save(onnx_opt_graph, onnx_opt_path) + else: + logger.warning(f"Found cached optimized model: {onnx_opt_path} ") + + # Build TensorRT engines + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + engine = Engine(engine_path) + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + + if force_engine_rebuild or not os.path.exists(engine.engine_path): + engine.build(onnx_opt_path, + fp16=True, + input_profile=model_obj.get_input_profile( + opt_batch_size, opt_image_height, opt_image_width, + static_batch=static_batch, static_shape=static_shape + ), + enable_preview=enable_preview, + timing_cache=timing_cache, + workspace_size=max_workspace_size) + built_engines[model_name] = engine + + # Load and activate TensorRT engines + for model_name, model_obj in models.items(): + engine = built_engines[model_name] + engine.load() + engine.activate() + + return built_engines + +def runEngine(engine, feed_dict, stream): + return engine.infer(feed_dict, stream) From d95dda78617c839c2671d08a559ac33365522c8a Mon Sep 17 00:00:00 2001 From: Asfiya Baig Date: Wed, 5 Apr 2023 09:23:24 -0700 Subject: [PATCH 2/5] update installation command Signed-off-by: Asfiya Baig --- examples/community/stable_diffusion_tensorrt_txt2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py index d26959abbe13..d8eab3f6c9e8 100644 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -33,7 +33,7 @@ ''' Installation instructions -python3 -m pip install --upgrade tensorrt polygraphy onnx-graphsurgeon -i https://pypi.ngc.nvidia.com +python3 -m pip install --upgrade tensorrt polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com python3 -m pip install onnxruntime ''' From 0c94f3b3d7f8789d56f8664d96e8a6774644c94f Mon Sep 17 00:00:00 2001 From: Asfiya Baig Date: Wed, 5 Apr 2023 10:48:39 -0700 Subject: [PATCH 3/5] update tensorrt installation Signed-off-by: Asfiya Baig --- examples/community/stable_diffusion_tensorrt_txt2img.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py index d8eab3f6c9e8..b7fa0f26f728 100644 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -33,7 +33,8 @@ ''' Installation instructions -python3 -m pip install --upgrade tensorrt polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +python3 -m pip install --upgrade tensorrt +python3 -m pip install --upgrade polygraphy 
onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com python3 -m pip install onnxruntime ''' From ea0337648e5f32079cdfaa72018dd79f3a3f5631 Mon Sep 17 00:00:00 2001 From: Asfiya Baig Date: Tue, 11 Apr 2023 09:31:31 -0700 Subject: [PATCH 4/5] changes 1. Update setting of cache directory 2. Address comments: merge utils and pipeline code. 3. Address comments: Add section in README Signed-off-by: Asfiya Baig --- examples/community/README.md | 33 +- .../stable_diffusion_tensorrt_txt2img.py | 909 ++++++++++++------ examples/community/tensorrt_utils.py | 355 ------- 3 files changed, 671 insertions(+), 626 deletions(-) delete mode 100644 examples/community/tensorrt_utils.py diff --git a/examples/community/README.md b/examples/community/README.md index 11da90764579..8b5b1743203d 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -31,7 +31,7 @@ MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | | DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) | | CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) | - +| TensorRT Stable Diffusion Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - |[Asfiya Baig](https://github.com/asfiyab-nvidia) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. @@ -1130,3 +1130,34 @@ Init Image Output Image ![img2img_clip_guidance](https://huggingface.co/datasets/njindal/images/resolve/main/clip_guided_img2img.jpg) + +### TensorRT Text2Image Stable Diffusion Pipeline + +The TensorRT Pipeline can be used to accelerate the Text2Image Stable Diffusion Inference run. + +NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes. + +```python +import torch +from diffusers import DDIMScheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline + +# Use the DDIMScheduler scheduler here instead +scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", + subfolder="scheduler") + +pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", + custom_pipeline="stable_diffusion_tensorrt_txt2img", + revision='fp16', + torch_dtype=torch.float16, + scheduler=scheduler,) + +# re-use cached folder to save ONNX models and TensorRT Engines +pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision='fp16',) + +pipe = pipe.to("cuda") + +prompt = "a beautiful photograph of Mt. 
Fuji during cherry blossom" +image = pipe(prompt).images[0] +image.save('tensorrt_mt_fuji.png') +``` diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py index b7fa0f26f728..a81a4a4c0fbf 100644 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -16,54 +16,551 @@ # limitations under the License. import os +import gc import numpy as np +from copy import copy +from collections import OrderedDict from typing import List, Optional, Union +import onnx +from onnx import shape_inference import torch import tensorrt as trt +import onnx_graphsurgeon as gs from polygraphy import cuda +from polygraphy.backend.onnx.loader import fold_constants +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.trt import CreateConfig, Profile +from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine +from polygraphy.backend.trt import util as trt_util from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from tensorrt_utils import TRT_LOGGER, build_engines, device_view, runEngine, Optimizer, BaseModel from huggingface_hub import snapshot_download + from diffusers.utils import DIFFUSERS_CACHE, logging from diffusers.schedulers import DDIMScheduler from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionPipeline, + StableDiffusionPipelineOutput, + StableDiffusionSafetyChecker, +) -''' +""" Installation instructions python3 -m pip install --upgrade tensorrt python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com python3 -m pip install onnxruntime -''' - -EXAMPLE_DOC_STRING = """ - Examples: - ```python3 - >>> import torch - >>> from diffusers import DDIMScheduler - - >>> # Use the DDIMScheduler scheduler here instead - >>> scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", - subfolder="scheduler") - - >>> pipe = TensorRTStableDiffusionPipeline("stabilityai/stable-diffusion-2-1", - revision='fp16', - torch_dtype=torch.float16, - scheduler=scheduler - ) - - >>> pipe = pipe.to("cuda") - - >>> prompt = "a beautiful photograph of Mt. 
Fuji during cherry blossom" - >>> image = pipe(prompt).images[0] - >>> image.save('tensorrt_mt_fuji.png') - ``` """ +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} +if np.version.full_version >= "1.24.0": + numpy_to_torch_dtype_dict[np.bool_] = torch.bool +else: + numpy_to_torch_dtype_dict[np.bool] = torch.bool + +# Map of torch dtype -> numpy dtype +torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} + + +def device_view(t): + return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype]) + + +class Engine: + def __init__(self, engine_path): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + + def __del__(self): + [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] + del self.engine + del self.context + del self.buffers + del self.tensors + + def build( + self, + onnx_path, + fp16, + input_profile=None, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + workspace_size=0, + ): + logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + + config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805] + if enable_preview: + # Faster dynamic shapes made optional since it increases engine build time. 
+ config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805) + if workspace_size > 0: + config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + engine = engine_from_network( + network_from_onnx_path(onnx_path), + config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs), + save_timing_cache=timing_cache, + ) + save_engine(engine, path=self.engine_path) + + def load(self): + logger.warning(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self): + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for idx in range(trt_util.get_bindings_per_profile(self.engine)): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) + + def infer(self, feed_dict, stream): + start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) + # shallow copy of ordered dict + device_buffers = copy(self.buffers) + for name, buf in feed_dict.items(): + assert isinstance(buf, cuda.DeviceView) + device_buffers[name] = buf + bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] + noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) + if not noerror: + raise ValueError(f"ERROR: inference failed.") + + return self.tensors + + +class Optimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError("ERROR: model size exceeds supported 2GB limit") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + +class BaseModel: + def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77): + self.model = model + self.name = "SD Model" + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def 
get_model(self): + return self.model + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + onnx_opt_graph = opt.cleanup(return_onnx=True) + return onnx_opt_graph + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_shape else self.min_image_shape + max_image_height = image_height if static_shape else self.max_image_shape + min_image_width = image_width if static_shape else self.min_image_shape + max_image_width = image_width if static_shape else self.max_image_shape + min_latent_height = latent_height if static_shape else self.min_latent_shape + max_latent_height = latent_height if static_shape else self.max_latent_shape + min_latent_width = latent_width if static_shape else self.min_latent_shape + max_latent_width = latent_width if static_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +def getOnnxPath(model_name, onnx_dir, opt=True): + return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") + + +def getEnginePath(model_name, engine_dir): + return os.path.join(engine_dir, model_name + ".plan") + + +def build_engines( + models: dict, + engine_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_shape=True, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + max_workspace_size=0, +): + built_engines = {} + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + # Export models to ONNX + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + if force_engine_rebuild or not os.path.exists(engine_path): + logger.warning("Building Engines...") + logger.warning("Engine build can take a while to complete") + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + if force_engine_rebuild or not os.path.exists(onnx_path): + logger.warning(f"Exporting model: {onnx_path}") + model = model_obj.get_model() + with torch.inference_mode(), 
torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.warning(f"Found cached model: {onnx_path}") + + # Optimize onnx + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + logger.warning(f"Generating optimizing model: {onnx_opt_path}") + onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path)) + onnx.save(onnx_opt_graph, onnx_opt_path) + else: + logger.warning(f"Found cached optimized model: {onnx_opt_path} ") + + # Build TensorRT engines + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + engine = Engine(engine_path) + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + + if force_engine_rebuild or not os.path.exists(engine.engine_path): + engine.build( + onnx_opt_path, + fp16=True, + input_profile=model_obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_shape=static_shape, + ), + enable_preview=enable_preview, + timing_cache=timing_cache, + workspace_size=max_workspace_size, + ) + built_engines[model_name] = engine + + # Load and activate TensorRT engines + for model_name, model_obj in models.items(): + engine = built_engines[model_name] + engine.load() + engine.activate() + + return built_engines + + +def runEngine(engine, feed_dict, stream): + return engine.infer(feed_dict, stream) + + +class CLIP(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(CLIP, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "CLIP" + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + return ["text_embeddings", "pooler_output"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt_onnx_graph = opt.cleanup(return_onnx=True) + return opt_onnx_graph + + +def make_CLIP(model, device, max_batch_size, embedding_dim, 
inpaint=False): + return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class UNet(BaseModel): + def __init__( + self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4 + ): + super(UNet, self).__init__( + model=model, + fp16=fp16, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.name = "UNet" + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + +def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False): + return UNet( + model, + fp16=True, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + unet_dim=(9 if inpaint else 4), + ) + + +class VAE(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(VAE, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "VAE decoder" + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = self.get_minmax_dims( 
+ batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) + + +def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False): + return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + class TensorRTStableDiffusionPipeline(StableDiffusionPipeline): r""" Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion. @@ -102,19 +599,21 @@ def __init__( safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, - stages=['clip','unet','vae'], + stages=["clip", "unet", "vae"], image_height: int = 768, image_width: int = 768, max_batch_size: int = 16, # ONNX export parameters onnx_opset: int = 17, - onnx_dir: str = 'onnx', + onnx_dir: str = "onnx", # TensorRT engine build parameters - engine_dir: str = 'engine', + engine_dir: str = "engine", force_engine_rebuild: bool = False, - timing_cache: str = 'timing_cache' + timing_cache: str = "timing_cache", ): - super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker) + super().__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) self.vae.forward = self.vae.decode @@ -135,21 +634,48 @@ def __init__( if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: self.max_batch_size = 4 - self.stream = None # loaded in loadResources() - self.models = {} # loaded in __loadModels() - self.engine = {} # loaded in build_engines() + self.stream = None # loaded in loadResources() + self.models = {} # loaded in __loadModels() + self.engine = {} # loaded in build_engines() def __loadModels(self): # Load pipeline models self.embedding_dim = self.text_encoder.config.hidden_size - models_args = {'device': self.torch_device, 'max_batch_size': self.max_batch_size, \ - 'embedding_dim': self.embedding_dim, 'inpaint': self.inpaint} - if 'clip' in self.stages: - self.models['clip'] = make_CLIP(self.text_encoder, **models_args) - if 'unet' in self.stages: - self.models['unet'] = make_UNet(self.unet, **models_args) - if 'vae' in self.stages: - self.models['vae'] = make_VAE(self.vae, **models_args) + models_args = { + "device": self.torch_device, + "max_batch_size": self.max_batch_size, + "embedding_dim": self.embedding_dim, + "inpaint": self.inpaint, + } + if "clip" in self.stages: + self.models["clip"] = make_CLIP(self.text_encoder, **models_args) + if "unet" in self.stages: + self.models["unet"] = make_UNet(self.unet, **models_args) + if "vae" in self.stages: + self.models["vae"] = make_VAE(self.vae, **models_args) + + @classmethod + def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + 
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + cls.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False): super().to(torch_device, silence_dtype_warnings) @@ -166,45 +692,23 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dt self.__loadModels() # build engines - self.engine = build_engines(self.models, self.engine_dir, self.onnx_dir, self.onnx_opset, \ - opt_image_height=self.image_height, opt_image_width=self.image_width, \ - force_engine_rebuild=self.force_engine_rebuild, static_batch=self.build_static_batch, \ - static_shape=not self.build_dynamic_shape, enable_preview=self.build_preview_features, \ - timing_cache=self.timing_cache) + self.engine = build_engines( + self.models, + self.engine_dir, + self.onnx_dir, + self.onnx_opset, + opt_image_height=self.image_height, + opt_image_width=self.image_width, + force_engine_rebuild=self.force_engine_rebuild, + static_batch=self.build_static_batch, + static_shape=not self.build_dynamic_shape, + enable_preview=self.build_preview_features, + timing_cache=self.timing_cache, + ) return self - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs) - - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - - cls.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - ) - return pipe - - def __encode_prompt( - self, - prompt, - negative_prompt - ): + def __encode_prompt(self, prompt, negative_prompt): r""" Encodes the prompt into text encoder hidden states. @@ -217,36 +721,49 @@ def __encode_prompt( Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
""" # Tokenize prompt - text_input_ids = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ).input_ids.type(torch.int32).to(self.torch_device) + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) text_input_ids_inp = device_view(text_input_ids) # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = runEngine(self.engine['clip'], {"input_ids": text_input_ids_inp}, self.stream)['text_embeddings'].clone() + text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[ + "text_embeddings" + ].clone() # Tokenize negative prompt - uncond_input_ids = self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ).input_ids.type(torch.int32).to(self.torch_device) + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) uncond_input_ids_inp = device_view(uncond_input_ids) - uncond_embeddings = runEngine(self.engine['clip'], {"input_ids": uncond_input_ids_inp}, self.stream)['text_embeddings'] + uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[ + "text_embeddings" + ] # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) return text_embeddings - - def __denoise_latent(self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None): + def __denoise_latent( + self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None + ): if not isinstance(timesteps, torch.Tensor): timesteps = self.scheduler.timesteps for step_index, timestep in enumerate(timesteps): @@ -263,9 +780,11 @@ def __denoise_latent(self, latents, text_embeddings, timesteps=None, step_offset sample_inp = device_view(latent_model_input) timestep_inp = device_view(timestep_float) embeddings_inp = device_view(text_embeddings) - noise_pred = runEngine(self.engine['unet'], \ - {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp}, \ - self.stream)['latent'] + noise_pred = runEngine( + self.engine["unet"], + {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp}, + self.stream, + )["latent"] # Perform guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -273,11 +792,11 @@ def __denoise_latent(self, latents, text_embeddings, timesteps=None, step_offset latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - latents = 1. 
/ 0.18215 * latents + latents = 1.0 / 0.18215 * latents return latents def __decode_latent(self, latents): - images = runEngine(self.engine['vae'], {"latent": device_view(latents)}, self.stream)['images'] + images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"] images = (images / 2 + 0.5).clamp(0, 1) return images.cpu().permute(0, 2, 3, 1).float().numpy() @@ -286,7 +805,9 @@ def __loadResources(self, image_height, image_width, batch_size): # Allocate buffers for TensorRT engine bindings for model_name, obj in self.models.items(): - self.engine[model_name].allocate_buffers(shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device) + self.engine[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device + ) @torch.no_grad() def __call__( @@ -347,24 +868,28 @@ def __call__( assert len(prompt) == len(negative_prompt) if batch_size > self.max_batch_size: - raise ValueError(f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4") - + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + # load resources self.__loadResources(self.image_height, self.image_width, batch_size) - + with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER): # CLIP text encoder text_embeddings = self.__encode_prompt(prompt, negative_prompt) # Pre-initialize latents num_channels_latents = self.unet.in_channels - latents = self.prepare_latents(batch_size, - num_channels_latents, - self.image_height, - self.image_width, - torch.float32, - self.torch_device, - generator) + latents = self.prepare_latents( + batch_size, + num_channels_latents, + self.image_height, + self.image_width, + torch.float32, + self.torch_device, + generator, + ) # UNet denoiser latents = self.__denoise_latent(latents, text_embeddings) @@ -375,159 +900,3 @@ def __call__( images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) - - -class CLIP(BaseModel): - def __init__(self, - model, - device, - max_batch_size, - embedding_dim - ): - super(CLIP, self).__init__(model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) - self.name = "CLIP" - - def get_input_names(self): - return ['input_ids'] - - def get_output_names(self): - return ['text_embeddings', 'pooler_output'] - - def get_dynamic_axes(self): - return { - 'input_ids': {0: 'B'}, - 'text_embeddings': {0: 'B'} - } - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): - self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) - return { - 'input_ids': [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return { - 'input_ids': (batch_size, self.text_maxlen), - 'text_embeddings': (batch_size, self.text_maxlen, self.embedding_dim) - } - - def get_sample_input(self, batch_size, image_height, image_width): - 
self.check_dims(batch_size, image_height, image_width) - return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) - - def optimize(self, onnx_graph): - opt = Optimizer(onnx_graph) - opt.select_outputs([0]) # delete graph output#1 - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.select_outputs([0], names=['text_embeddings']) # rename network output - opt_onnx_graph = opt.cleanup(return_onnx=True) - return opt_onnx_graph - -def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False): - return CLIP(model, device=device, max_batch_size=max_batch_size, \ - embedding_dim=embedding_dim) - -class UNet(BaseModel): - def __init__(self, - model, - fp16=False, - device='cuda', - max_batch_size=16, - embedding_dim=768, - text_maxlen=77, - unet_dim=4 - ): - super(UNet, self).__init__(model=model, fp16=fp16, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim, text_maxlen=text_maxlen) - self.unet_dim = unet_dim - self.name = "UNet" - - def get_input_names(self): - return ['sample', 'timestep', 'encoder_hidden_states'] - - def get_output_names(self): - return ['latent'] - - def get_dynamic_axes(self): - return { - 'sample': {0: '2B', 2: 'H', 3: 'W'}, - 'encoder_hidden_states': {0: '2B'}, - 'latent': {0: '2B', 2: 'H', 3: 'W'} - } - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \ - self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) - return { - 'sample': [(2*min_batch, self.unet_dim, min_latent_height, min_latent_width), (2*batch_size, self.unet_dim, latent_height, latent_width), (2*max_batch, self.unet_dim, max_latent_height, max_latent_width)], - 'encoder_hidden_states': [(2*min_batch, self.text_maxlen, self.embedding_dim), (2*batch_size, self.text_maxlen, self.embedding_dim), (2*max_batch, self.text_maxlen, self.embedding_dim)] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - 'sample': (2*batch_size, self.unet_dim, latent_height, latent_width), - 'encoder_hidden_states': (2*batch_size, self.text_maxlen, self.embedding_dim), - 'latent': (2*batch_size, 4, latent_height, latent_width) - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - dtype = torch.float16 if self.fp16 else torch.float32 - return ( - torch.randn(2*batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device), - torch.tensor([1.], dtype=torch.float32, device=self.device), - torch.randn(2*batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device) - ) - -def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False): - return UNet(model, fp16=True, device=device, max_batch_size=max_batch_size, \ - embedding_dim=embedding_dim, unet_dim=(9 if inpaint else 4)) - -class VAE(BaseModel): - def __init__(self, - model, - device, - max_batch_size, - embedding_dim - ): - super(VAE, self).__init__(model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) - self.name = "VAE decoder" - - def get_input_names(self): - return ['latent'] - - def 
get_output_names(self): - return ['images'] - - def get_dynamic_axes(self): - return { - 'latent': {0: 'B', 2: 'H', 3: 'W'}, - 'images': {0: 'B', 2: '8H', 3: '8W'} - } - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = \ - self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) - return { - 'latent': [(min_batch, 4, min_latent_height, min_latent_width), (batch_size, 4, latent_height, latent_width), (max_batch, 4, max_latent_height, max_latent_width)] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - 'latent': (batch_size, 4, latent_height, latent_width), - 'images': (batch_size, 3, image_height, image_width) - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) - -def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False): - return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) diff --git a/examples/community/tensorrt_utils.py b/examples/community/tensorrt_utils.py deleted file mode 100644 index 78118a716407..000000000000 --- a/examples/community/tensorrt_utils.py +++ /dev/null @@ -1,355 +0,0 @@ -# -# Copyright 2023 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
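A note for readers before the removed file scrolls past: the `get_input_profile()` implementations above all return a `{name: [min, opt, max]}` dict of shape triples, and `Engine.build()` in the `tensorrt_utils.py` deleted below consumes it by folding each triple into a Polygraphy `Profile`. A minimal sketch of that conversion follows; the shapes are illustrative only (batch 1, with the 32/96/128 latent bounds assuming the 256/768/1024 image limits from `BaseModel`):

```python
# Sketch: how a get_input_profile()-style dict becomes a TensorRT optimization
# profile via Polygraphy, mirroring the loop inside Engine.build().
from polygraphy.backend.trt import Profile

input_profile = {
    "latent": [
        (1, 4, 32, 32),    # min:  256x256 image  ->  32x32 latents
        (1, 4, 96, 96),    # opt:  768x768 image  ->  96x96 latents
        (1, 4, 128, 128),  # max: 1024x1024 image -> 128x128 latents
    ],
}

p = Profile()
for name, dims in input_profile.items():
    assert len(dims) == 3  # exactly (min, opt, max)
    p.add(name, min=dims[0], opt=dims[1], max=dims[2])
```

TensorRT tunes kernels for the `opt` shape while keeping the engine valid for anything between `min` and `max`, which is why `get_minmax_dims()` pins all three to the same value when `static_batch`/`static_shape` are set.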
- -import os -import gc -from collections import OrderedDict -import onnx -import torch -import numpy as np -from copy import copy -import tensorrt as trt -import onnx_graphsurgeon as gs -from onnx import shape_inference -from polygraphy.backend.onnx.loader import fold_constants -from polygraphy.backend.common import bytes_from_path -from polygraphy.backend.trt import CreateConfig, Profile -from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine -from polygraphy.backend.trt import util as trt_util -from polygraphy import cuda -from diffusers.utils import logging - -TRT_LOGGER = trt.Logger(trt.Logger.ERROR) -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -# Map of numpy dtype -> torch dtype -numpy_to_torch_dtype_dict = { - np.uint8 : torch.uint8, - np.int8 : torch.int8, - np.int16 : torch.int16, - np.int32 : torch.int32, - np.int64 : torch.int64, - np.float16 : torch.float16, - np.float32 : torch.float32, - np.float64 : torch.float64, - np.complex64 : torch.complex64, - np.complex128 : torch.complex128 -} -if np.version.full_version >= "1.24.0": - numpy_to_torch_dtype_dict[np.bool_] = torch.bool -else: - numpy_to_torch_dtype_dict[np.bool] = torch.bool - -# Map of torch dtype -> numpy dtype -torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()} - -def device_view(t): - return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype]) - -class Engine(): - def __init__( - self, - engine_path, - ): - self.engine_path = engine_path - self.engine = None - self.context = None - self.buffers = OrderedDict() - self.tensors = OrderedDict() - - def __del__(self): - [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray) ] - del self.engine - del self.context - del self.buffers - del self.tensors - - def build(self, onnx_path, fp16, input_profile=None, enable_preview=False, enable_all_tactics=False, timing_cache=None, workspace_size=0): - logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") - p = Profile() - if input_profile: - for name, dims in input_profile.items(): - assert len(dims) == 3 - p.add(name, min=dims[0], opt=dims[1], max=dims[2]) - - config_kwargs = {} - - config_kwargs['preview_features'] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805] - if enable_preview: - # Faster dynamic shapes made optional since it increases engine build time. 
- config_kwargs['preview_features'].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805) - if workspace_size > 0: - config_kwargs['memory_pool_limits'] = {trt.MemoryPoolType.WORKSPACE: workspace_size} - if not enable_all_tactics: - config_kwargs['tactic_sources'] = [] - - engine = engine_from_network( - network_from_onnx_path(onnx_path), - config=CreateConfig(fp16=fp16, - profiles=[p], - load_timing_cache=timing_cache, - **config_kwargs - ), - save_timing_cache=timing_cache - ) - save_engine(engine, path=self.engine_path) - - def load(self): - logger.warning(f"Loading TensorRT engine: {self.engine_path}") - self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) - - def activate(self): - self.context = self.engine.create_execution_context() - - def allocate_buffers(self, shape_dict=None, device='cuda'): - for idx in range(trt_util.get_bindings_per_profile(self.engine)): - binding = self.engine[idx] - if shape_dict and binding in shape_dict: - shape = shape_dict[binding] - else: - shape = self.engine.get_binding_shape(binding) - dtype = trt.nptype(self.engine.get_binding_dtype(binding)) - if self.engine.binding_is_input(binding): - self.context.set_binding_shape(idx, shape) - tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) - self.tensors[binding] = tensor - self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) - - def infer(self, feed_dict, stream): - start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) - # shallow copy of ordered dict - device_buffers = copy(self.buffers) - for name, buf in feed_dict.items(): - assert isinstance(buf, cuda.DeviceView) - device_buffers[name] = buf - bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] - noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) - if not noerror: - raise ValueError(f"ERROR: inference failed.") - - return self.tensors - -class Optimizer(): - def __init__( - self, - onnx_graph, - ): - self.graph = gs.import_onnx(onnx_graph) - - def cleanup(self, return_onnx=False): - self.graph.cleanup().toposort() - if return_onnx: - return gs.export_onnx(self.graph) - - def select_outputs(self, keep, names=None): - self.graph.outputs = [self.graph.outputs[o] for o in keep] - if names: - for i, name in enumerate(names): - self.graph.outputs[i].name = name - - def fold_constants(self, return_onnx=False): - onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) - self.graph = gs.import_onnx(onnx_graph) - if return_onnx: - return onnx_graph - - def infer_shapes(self, return_onnx=False): - onnx_graph = gs.export_onnx(self.graph) - if onnx_graph.ByteSize() > 2147483648: - raise TypeError("ERROR: model size exceeds supported 2GB limit") - else: - onnx_graph = shape_inference.infer_shapes(onnx_graph) - - self.graph = gs.import_onnx(onnx_graph) - if return_onnx: - return onnx_graph - - -class BaseModel(): - def __init__( - self, - model, - fp16=False, - device='cuda', - max_batch_size=16, - embedding_dim=768, - text_maxlen=77, - ): - self.model = model - self.name = "SD Model" - self.fp16 = fp16 - self.device = device - - self.min_batch = 1 - self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 1024 # max image resolution: 1024x1024 - self.min_latent_shape = self.min_image_shape // 8 - self.max_latent_shape = self.max_image_shape // 8 - - self.embedding_dim = embedding_dim - 
self.text_maxlen = text_maxlen - - def get_model(self): - return self.model - - def get_input_names(self): - pass - - def get_output_names(self): - pass - - def get_dynamic_axes(self): - return None - - def get_sample_input(self, batch_size, image_height, image_width): - pass - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): - return None - - def get_shape_dict(self, batch_size, image_height, image_width): - return None - - def optimize(self, onnx_graph): - opt = Optimizer(onnx_graph) - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - onnx_opt_graph = opt.cleanup(return_onnx=True) - return onnx_opt_graph - - def check_dims(self, batch_size, image_height, image_width): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 - latent_height = image_height // 8 - latent_width = image_width // 8 - assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape - assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape - return (latent_height, latent_width) - - def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - latent_height = image_height // 8 - latent_width = image_width // 8 - min_image_height = image_height if static_shape else self.min_image_shape - max_image_height = image_height if static_shape else self.max_image_shape - min_image_width = image_width if static_shape else self.min_image_shape - max_image_width = image_width if static_shape else self.max_image_shape - min_latent_height = latent_height if static_shape else self.min_latent_shape - max_latent_height = latent_height if static_shape else self.max_latent_shape - min_latent_width = latent_width if static_shape else self.min_latent_shape - max_latent_width = latent_width if static_shape else self.max_latent_shape - return (min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width) - - -def getOnnxPath(model_name, onnx_dir, opt=True): - return os.path.join(onnx_dir, model_name+('.opt' if opt else '')+'.onnx') - -def getEnginePath(model_name, engine_dir): - return os.path.join(engine_dir, model_name+'.plan') - -def build_engines( - models: dict, - engine_dir, - onnx_dir, - onnx_opset, - opt_image_height, - opt_image_width, - opt_batch_size=1, - force_engine_rebuild=False, - static_batch=False, - static_shape=True, - enable_preview=False, - enable_all_tactics=False, - timing_cache=None, - max_workspace_size=0 -): - built_engines = {} - if not os.path.isdir(onnx_dir): - os.makedirs(onnx_dir) - if not os.path.isdir(engine_dir): - os.makedirs(engine_dir) - - # Export models to ONNX - for model_name, model_obj in models.items(): - engine_path = getEnginePath(model_name, engine_dir) - if force_engine_rebuild or not os.path.exists(engine_path): - logger.warning("Building Engines...") - logger.warning("Engine build can take a while to complete") - onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) - onnx_opt_path = getOnnxPath(model_name, onnx_dir) - if force_engine_rebuild or not os.path.exists(onnx_opt_path): - if force_engine_rebuild or not os.path.exists(onnx_path): - logger.warning(f"Exporting model: {onnx_path}") - model = model_obj.get_model() - with torch.inference_mode(), 
torch.autocast("cuda"): - inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) - torch.onnx.export(model, - inputs, - onnx_path, - export_params=True, - opset_version=onnx_opset, - do_constant_folding=True, - input_names=model_obj.get_input_names(), - output_names=model_obj.get_output_names(), - dynamic_axes=model_obj.get_dynamic_axes(), - ) - del model - torch.cuda.empty_cache() - gc.collect() - else: - logger.warning(f"Found cached model: {onnx_path}") - - # Optimize onnx - if force_engine_rebuild or not os.path.exists(onnx_opt_path): - logger.warning(f"Generating optimizing model: {onnx_opt_path}") - onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path)) - onnx.save(onnx_opt_graph, onnx_opt_path) - else: - logger.warning(f"Found cached optimized model: {onnx_opt_path} ") - - # Build TensorRT engines - for model_name, model_obj in models.items(): - engine_path = getEnginePath(model_name, engine_dir) - engine = Engine(engine_path) - onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) - onnx_opt_path = getOnnxPath(model_name, onnx_dir) - - if force_engine_rebuild or not os.path.exists(engine.engine_path): - engine.build(onnx_opt_path, - fp16=True, - input_profile=model_obj.get_input_profile( - opt_batch_size, opt_image_height, opt_image_width, - static_batch=static_batch, static_shape=static_shape - ), - enable_preview=enable_preview, - timing_cache=timing_cache, - workspace_size=max_workspace_size) - built_engines[model_name] = engine - - # Load and activate TensorRT engines - for model_name, model_obj in models.items(): - engine = built_engines[model_name] - engine.load() - engine.activate() - - return built_engines - -def runEngine(engine, feed_dict, stream): - return engine.infer(feed_dict, stream) From 980280fef2bfa52eab5da9c3a31dbc896de248b2 Mon Sep 17 00:00:00 2001 From: Asfiya Baig Date: Wed, 12 Apr 2023 10:24:02 -0700 Subject: [PATCH 5/5] apply make style Signed-off-by: Asfiya Baig --- .../stable_diffusion_tensorrt_txt2img.py | 64 +++++++++++++------ 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py index a81a4a4c0fbf..7aef2bec743f 100644 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -15,35 +15,42 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os import gc -import numpy as np -from copy import copy +import os from collections import OrderedDict +from copy import copy from typing import List, Optional, Union +import numpy as np import onnx -from onnx import shape_inference -import torch -import tensorrt as trt import onnx_graphsurgeon as gs +import tensorrt as trt +import torch +from huggingface_hub import snapshot_download +from onnx import shape_inference from polygraphy import cuda -from polygraphy.backend.onnx.loader import fold_constants from polygraphy.backend.common import bytes_from_path -from polygraphy.backend.trt import CreateConfig, Profile -from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine +from polygraphy.backend.onnx.loader import fold_constants +from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) from polygraphy.backend.trt import util as trt_util from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from huggingface_hub import snapshot_download -from diffusers.utils import DIFFUSERS_CACHE, logging -from diffusers.schedulers import DDIMScheduler from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( StableDiffusionPipeline, StableDiffusionPipelineOutput, StableDiffusionSafetyChecker, ) +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import DIFFUSERS_CACHE, logging + """ Installation instructions @@ -162,7 +169,7 @@ def infer(self, feed_dict, stream): bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) if not noerror: - raise ValueError(f"ERROR: inference failed.") + raise ValueError("ERROR: inference failed.") return self.tensors @@ -469,9 +476,18 @@ def get_dynamic_axes(self): def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) return { "sample": [ (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), @@ -534,9 +550,18 @@ def get_dynamic_axes(self): def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, min_latent_height, max_latent_height, min_latent_width, max_latent_width = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_shape - ) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) return { "latent": [ (min_batch, 4, min_latent_height, min_latent_width), @@ -774,7 +799,6 @@ def __denoise_latent( latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # Predict the noise 
residual - embeddings_dtype = np.float16 timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep sample_inp = device_view(latent_model_input)
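The hunk above ends partway through the reformatted `__denoise_latent()` loop. For orientation, the step that follows the `runEngine` call splits the doubled batch and applies classifier-free guidance, visible earlier in this patch as `noise_pred.chunk(2)`. A sketch of that arithmetic; `guidance_scale=7.5` is an illustrative default, not a value taken from this patch:

```python
import torch


def apply_guidance(noise_pred: torch.Tensor, guidance_scale: float = 7.5) -> torch.Tensor:
    # The UNet ran on a 2B batch (unconditional + conditional), matching the
    # "2B" dynamic axes in UNet.get_dynamic_axes(); split it back apart and
    # steer the prediction toward the text-conditioned branch.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```

The `latents = 1.0 / 0.18215 * latents` line seen earlier then rescales the result back to the VAE's native latent scale before decoding.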