diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py index 041cf3a12dbd..507177791f5e 100755 --- a/examples/community/stable_diffusion_tensorrt_img2img.py +++ b/examples/community/stable_diffusion_tensorrt_img2img.py @@ -41,7 +41,7 @@ save_engine, ) from polygraphy.backend.trt import util as trt_util -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( @@ -709,6 +709,7 @@ def __init__( scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, stages=["clip", "unet", "vae", "vae_encoder"], image_height: int = 512, @@ -724,7 +725,15 @@ def __init__( timing_cache: str = "timing_cache", ): super().__init__( - vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + requires_safety_checker=requires_safety_checker, ) self.vae.forward = self.vae.decode diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py index 71fa1b0a5f11..b4e16c76159c 100755 --- a/examples/community/stable_diffusion_tensorrt_inpaint.py +++ b/examples/community/stable_diffusion_tensorrt_inpaint.py @@ -41,7 +41,7 @@ save_engine, ) from polygraphy.backend.trt import util as trt_util -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( @@ -710,6 +710,7 @@ def __init__( scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, stages=["clip", "unet", "vae", "vae_encoder"], image_height: int = 512, @@ -725,7 +726,15 @@ def __init__( timing_cache: str = "timing_cache", ): super().__init__( - vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + requires_safety_checker=requires_safety_checker, ) self.vae.forward = self.vae.decode diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py index b51f3176b958..c38261463384 100755 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -40,7 +40,7 @@ save_engine, ) from polygraphy.backend.trt import util as trt_util -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( @@ -624,6 +624,7 @@ def __init__( scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, stages=["clip", "unet", "vae"], image_height: int = 768, @@ -639,7 +640,15 @@ def __init__( timing_cache: str = "timing_cache", ): super().__init__( - vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + requires_safety_checker=requires_safety_checker, ) self.vae.forward = self.vae.decode