diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index 0c50e424e249..ecc0b31f3456 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -18,6 +18,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): + vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel tokenizer: CLIPTokenizer @@ -268,6 +269,7 @@ def __call__( class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): def __init__( self, + vae_encoder: OnnxRuntimeModel, vae_decoder: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, @@ -279,6 +281,7 @@ def __init__( deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`." deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) super().__init__( + vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, tokenizer=tokenizer,