1818# --------------------------------------------------------------------------
1919
2020
21+ import logging
2122import math
2223from typing import Dict , Union
2324
2425import matplotlib
2526import numpy as np
2627import torch
2728from PIL import Image
29+ from PIL .Image import Resampling
2830from scipy .optimize import minimize
2931from torch .utils .data import DataLoader , TensorDataset
3032from tqdm .auto import tqdm
3436 AutoencoderKL ,
3537 DDIMScheduler ,
3638 DiffusionPipeline ,
39+ LCMScheduler ,
3740 UNet2DConditionModel ,
3841)
3942from diffusers .utils import BaseOutput , check_min_version
4043
4144
4245# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
43- check_min_version ("0.28.0.dev0 " )
46+ check_min_version ("0.25.0 " )
4447
4548
4649class MarigoldDepthOutput (BaseOutput ):
@@ -61,6 +64,19 @@ class MarigoldDepthOutput(BaseOutput):
6164 uncertainty : Union [None , np .ndarray ]
6265
6366
67+ def get_pil_resample_method (method_str : str ) -> Resampling :
68+ resample_method_dic = {
69+ "bilinear" : Resampling .BILINEAR ,
70+ "bicubic" : Resampling .BICUBIC ,
71+ "nearest" : Resampling .NEAREST ,
72+ }
73+ resample_method = resample_method_dic .get (method_str , None )
74+ if resample_method is None :
75+ raise ValueError (f"Unknown resampling method: { resample_method } " )
76+ else :
77+ return resample_method
78+
79+
6480class MarigoldPipeline (DiffusionPipeline ):
6581 """
6682 Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
@@ -113,7 +129,9 @@ def __call__(
113129 ensemble_size : int = 10 ,
114130 processing_res : int = 768 ,
115131 match_input_res : bool = True ,
132+ resample_method : str = "bilinear" ,
116133 batch_size : int = 0 ,
134+ seed : Union [int , None ] = None ,
117135 color_map : str = "Spectral" ,
118136 show_progress_bar : bool = True ,
119137 ensemble_kwargs : Dict = None ,
@@ -129,14 +147,18 @@ def __call__(
129147 If set to 0: will not resize at all.
130148 match_input_res (`bool`, *optional*, defaults to `True`):
131149 Resize depth prediction to match input resolution.
132- Only valid if `limit_input_res` is not None.
150+ Only valid if `processing_res` > 0.
151+ resample_method: (`str`, *optional*, defaults to `bilinear`):
152+ Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
133153 denoising_steps (`int`, *optional*, defaults to `10`):
134154 Number of diffusion denoising steps (DDIM) during inference.
135155 ensemble_size (`int`, *optional*, defaults to `10`):
136156 Number of predictions to be ensembled.
137157 batch_size (`int`, *optional*, defaults to `0`):
138158 Inference batch size, no bigger than `num_ensemble`.
139159 If set to 0, the script will automatically decide the proper batch size.
160+ seed (`int`, *optional*, defaults to `None`)
161+ Reproducibility seed.
140162 show_progress_bar (`bool`, *optional*, defaults to `True`):
141163 Display a progress bar of diffusion denoising.
142164 color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
@@ -146,8 +168,7 @@ def __call__(
146168 Returns:
147169 `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
148170 - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
149- - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
150- values in [0, 1]. None if `color_map` is `None`
171+ - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
151172 - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
152173 coming from ensembling. None if `ensemble_size = 1`
153174 """
@@ -158,13 +179,21 @@ def __call__(
158179 if not match_input_res :
159180 assert processing_res is not None , "Value error: `resize_output_back` is only valid with "
160181 assert processing_res >= 0
161- assert denoising_steps >= 1
162182 assert ensemble_size >= 1
163183
184+ # Check if denoising step is reasonable
185+ self ._check_inference_step (denoising_steps )
186+
187+ resample_method : Resampling = get_pil_resample_method (resample_method )
188+
164189 # ----------------- Image Preprocess -----------------
165190 # Resize image
166191 if processing_res > 0 :
167- input_image = self .resize_max_res (input_image , max_edge_resolution = processing_res )
192+ input_image = self .resize_max_res (
193+ input_image ,
194+ max_edge_resolution = processing_res ,
195+ resample_method = resample_method ,
196+ )
168197 # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
169198 input_image = input_image .convert ("RGB" )
170199 image = np .asarray (input_image )
@@ -203,9 +232,10 @@ def __call__(
203232 rgb_in = batched_img ,
204233 num_inference_steps = denoising_steps ,
205234 show_pbar = show_progress_bar ,
235+ seed = seed ,
206236 )
207- depth_pred_ls .append (depth_pred_raw .detach (). clone () )
208- depth_preds = torch .concat (depth_pred_ls , axis = 0 ).squeeze ()
237+ depth_pred_ls .append (depth_pred_raw .detach ())
238+ depth_preds = torch .concat (depth_pred_ls , dim = 0 ).squeeze ()
209239 torch .cuda .empty_cache () # clear vram cache for ensembling
210240
211241 # ----------------- Test-time ensembling -----------------
@@ -227,7 +257,7 @@ def __call__(
227257 # Resize back to original resolution
228258 if match_input_res :
229259 pred_img = Image .fromarray (depth_pred )
230- pred_img = pred_img .resize (input_size )
260+ pred_img = pred_img .resize (input_size , resample = resample_method )
231261 depth_pred = np .asarray (pred_img )
232262
233263 # Clip output range
@@ -243,12 +273,32 @@ def __call__(
243273 depth_colored_img = Image .fromarray (depth_colored_hwc )
244274 else :
245275 depth_colored_img = None
276+
246277 return MarigoldDepthOutput (
247278 depth_np = depth_pred ,
248279 depth_colored = depth_colored_img ,
249280 uncertainty = pred_uncert ,
250281 )
251282
283+ def _check_inference_step (self , n_step : int ):
284+ """
285+ Check if denoising step is reasonable
286+ Args:
287+ n_step (`int`): denoising steps
288+ """
289+ assert n_step >= 1
290+
291+ if isinstance (self .scheduler , DDIMScheduler ):
292+ if n_step < 10 :
293+ logging .warning (
294+ f"Too few denoising steps: { n_step } . Recommended to use the LCM checkpoint for few-step inference."
295+ )
296+ elif isinstance (self .scheduler , LCMScheduler ):
297+ if not 1 <= n_step <= 4 :
298+ logging .warning (f"Non-optimal setting of denoising steps: { n_step } . Recommended setting is 1-4 steps." )
299+ else :
300+ raise RuntimeError (f"Unsupported scheduler type: { type (self .scheduler )} " )
301+
252302 def _encode_empty_text (self ):
253303 """
254304 Encode text embedding for empty prompt.
@@ -265,7 +315,13 @@ def _encode_empty_text(self):
265315 self .empty_text_embed = self .text_encoder (text_input_ids )[0 ].to (self .dtype )
266316
267317 @torch .no_grad ()
268- def single_infer (self , rgb_in : torch .Tensor , num_inference_steps : int , show_pbar : bool ) -> torch .Tensor :
318+ def single_infer (
319+ self ,
320+ rgb_in : torch .Tensor ,
321+ num_inference_steps : int ,
322+ seed : Union [int , None ],
323+ show_pbar : bool ,
324+ ) -> torch .Tensor :
269325 """
270326 Perform an individual depth prediction without ensembling.
271327
@@ -286,10 +342,20 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
286342 timesteps = self .scheduler .timesteps # [T]
287343
288344 # Encode image
289- rgb_latent = self ._encode_rgb (rgb_in )
345+ rgb_latent = self .encode_rgb (rgb_in )
290346
291347 # Initial depth map (noise)
292- depth_latent = torch .randn (rgb_latent .shape , device = device , dtype = self .dtype ) # [B, 4, h, w]
348+ if seed is None :
349+ rand_num_generator = None
350+ else :
351+ rand_num_generator = torch .Generator (device = device )
352+ rand_num_generator .manual_seed (seed )
353+ depth_latent = torch .randn (
354+ rgb_latent .shape ,
355+ device = device ,
356+ dtype = self .dtype ,
357+ generator = rand_num_generator ,
358+ ) # [B, 4, h, w]
293359
294360 # Batched empty text embedding
295361 if self .empty_text_embed is None :
@@ -314,9 +380,9 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
314380 noise_pred = self .unet (unet_input , t , encoder_hidden_states = batch_empty_text_embed ).sample # [B, 4, h, w]
315381
316382 # compute the previous noisy sample x_t -> x_t-1
317- depth_latent = self .scheduler .step (noise_pred , t , depth_latent ).prev_sample
318- torch . cuda . empty_cache ()
319- depth = self ._decode_depth (depth_latent )
383+ depth_latent = self .scheduler .step (noise_pred , t , depth_latent , generator = rand_num_generator ).prev_sample
384+
385+ depth = self .decode_depth (depth_latent )
320386
321387 # clip prediction
322388 depth = torch .clip (depth , - 1.0 , 1.0 )
@@ -325,7 +391,7 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
325391
326392 return depth
327393
328- def _encode_rgb (self , rgb_in : torch .Tensor ) -> torch .Tensor :
394+ def encode_rgb (self , rgb_in : torch .Tensor ) -> torch .Tensor :
329395 """
330396 Encode RGB image into latent.
331397
@@ -344,7 +410,7 @@ def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
344410 rgb_latent = mean * self .rgb_latent_scale_factor
345411 return rgb_latent
346412
347- def _decode_depth (self , depth_latent : torch .Tensor ) -> torch .Tensor :
413+ def decode_depth (self , depth_latent : torch .Tensor ) -> torch .Tensor :
348414 """
349415 Decode depth latent into depth map.
350416
@@ -365,7 +431,7 @@ def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
365431 return depth_mean
366432
367433 @staticmethod
368- def resize_max_res (img : Image .Image , max_edge_resolution : int ) -> Image .Image :
434+ def resize_max_res (img : Image .Image , max_edge_resolution : int , resample_method = Resampling . BILINEAR ) -> Image .Image :
369435 """
370436 Resize image to limit maximum edge length while keeping aspect ratio.
371437
@@ -374,6 +440,8 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
374440 Image to be resized.
375441 max_edge_resolution (`int`):
376442 Maximum edge length (pixel).
443+ resample_method (`PIL.Image.Resampling`):
444+ Resampling method used to resize images.
377445
378446 Returns:
379447 `Image.Image`: Resized image.
@@ -384,7 +452,7 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
384452 new_width = int (original_width * downscale_factor )
385453 new_height = int (original_height * downscale_factor )
386454
387- resized_img = img .resize ((new_width , new_height ))
455+ resized_img = img .resize ((new_width , new_height ), resample = resample_method )
388456 return resized_img
389457
390458 @staticmethod
0 commit comments