diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py index 970272999c67..3bd95dd55358 100644 --- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import warnings +from typing import Optional import torch @@ -21,13 +22,20 @@ class KarrasVePipeline(DiffusionPipeline): unet: UNet2DModel scheduler: KarrasVeScheduler - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): super().__init__() scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs): + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + **kwargs, + ): if "torch_device" in kwargs: device = kwargs.pop("torch_device") warnings.warn(