From 1f13b1ac9cb6aeeaf14241d66f0a991670de7b17 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 10:57:45 +0200 Subject: [PATCH 1/2] [Type hint] Karras VE pipeline --- .../pipeline_stochastic_karras_ve.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py index 970272999c67..f90e932f2891 100644 --- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import warnings +from typing import Optional import torch @@ -21,13 +22,20 @@ class KarrasVePipeline(DiffusionPipeline): unet: UNet2DModel scheduler: KarrasVeScheduler - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): super().__init__() scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs): + def __call__( + self, + batch_size: Optional[int] = 1, + num_inference_steps: Optional[int] = 50, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + **kwargs, + ): if "torch_device" in kwargs: device = kwargs.pop("torch_device") warnings.warn( From 638555e5967fecec537540dd18fc66c944098324 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 12:35:29 +0200 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Anton Lozhkov --- .../stochatic_karras_ve/pipeline_stochastic_karras_ve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py index f90e932f2891..3bd95dd55358 100644 --- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -30,8 +30,8 @@ def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): @torch.no_grad() def __call__( self, - batch_size: Optional[int] = 1, - num_inference_steps: Optional[int] = 50, + batch_size: int = 1, + num_inference_steps: int = 50, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", **kwargs,