diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2f4651ba3417..f551c052976f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -93,6 +93,8 @@ title: Trajectory Consistency Distillation-LoRA - local: using-diffusers/svd title: Stable Video Diffusion + - local: using-diffusers/marigold_usage + title: Marigold Computer Vision title: Specific pipeline examples - sections: - local: training/overview @@ -295,6 +297,8 @@ title: Latent Diffusion - local: api/pipelines/ledits_pp title: LEDITS++ + - local: api/pipelines/marigold + title: Marigold - local: api/pipelines/panorama title: MultiDiffusion - local: api/pipelines/musicldm @@ -445,4 +449,4 @@ title: Video Processor title: Internal classes isExpanded: false - title: API + title: API \ No newline at end of file diff --git a/docs/source/en/api/pipelines/marigold.md b/docs/source/en/api/pipelines/marigold.md new file mode 100644 index 000000000000..e235368eb047 --- /dev/null +++ b/docs/source/en/api/pipelines/marigold.md @@ -0,0 +1,76 @@ + + +# Marigold Pipelines for Computer Vision Tasks + +![marigold](https://marigoldmonodepth.github.io/images/teaser_collage_compressed.jpg) + +Marigold was proposed in [Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation](https://huggingface.co/papers/2312.02145), a CVPR 2024 Oral paper by [Bingxin Ke](http://www.kebingxin.com/), [Anton Obukhov](https://www.obukhov.ai/), [Shengyu Huang](https://shengyuh.github.io/), [Nando Metzger](https://nandometzger.github.io/), [Rodrigo Caye Daudt](https://rcdaudt.github.io/), and [Konrad Schindler](https://scholar.google.com/citations?user=FZuNgqIAAAAJ&hl=en). +The idea is to repurpose the rich generative prior of Text-to-Image Latent Diffusion Models (LDMs) for traditional computer vision tasks. +Initially, this idea was explored to fine-tune Stable Diffusion for Monocular Depth Estimation, as shown in the teaser above. +Later, +- [Tianfu Wang](https://tianfwang.github.io/) trained the first Latent Consistency Model (LCM) of Marigold, which unlocked fast single-step inference; +- [Kevin Qu](https://www.linkedin.com/in/kevin-qu-b3417621b/?locale=en_US) extended the approach to Surface Normals Estimation; +- [Anton Obukhov](https://www.obukhov.ai/) contributed the pipelines and documentation into diffusers (enabled and supported by [YiYi Xu](https://yiyixuxu.github.io/) and [Sayak Paul](https://sayak.dev/)). + +The abstract from the paper is: + +*Monocular depth estimation is a fundamental computer vision task. Recovering 3D depth from a single image is geometrically ill-posed and requires scene understanding, so it is not surprising that the rise of deep learning has led to a breakthrough. The impressive progress of monocular depth estimators has mirrored the growth in model capacity, from relatively modest CNNs to large Transformer architectures. Still, monocular depth estimators tend to struggle when presented with images with unfamiliar content and layout, since their knowledge of the visual world is restricted by the data seen during training, and challenged by zero-shot generalization to new domains. This motivates us to explore whether the extensive priors captured in recent generative diffusion models can enable better, more generalizable depth estimation. We introduce Marigold, a method for affine-invariant monocular depth estimation that is derived from Stable Diffusion and retains its rich prior knowledge. The estimator can be fine-tuned in a couple of days on a single GPU using only synthetic training data. It delivers state-of-the-art performance across a wide range of datasets, including over 20% performance gains in specific cases. Project page: https://marigoldmonodepth.github.io.* + +## Available Pipelines + +Each pipeline supports one Computer Vision task, which takes an input RGB image as input and produces a *prediction* of the modality of interest, such as a depth map of the input image. +Currently, the following tasks are implemented: + +| Pipeline | Predicted Modalities | Demos | +|---------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------:| +| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-lcm), [Slow Original Demo (DDIM)](https://huggingface.co/spaces/prs-eth/marigold) | +| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-normals-lcm) | + + +## Available Checkpoints + +The original checkpoints can be found under the [PRS-ETH](https://huggingface.co/prs-eth/) Hugging Face organization. + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage). + + + + + +Marigold pipelines were designed and tested only with `DDIMScheduler` and `LCMScheduler`. +Depending on the scheduler, the number of inference steps required to get reliable predictions varies, and there is no universal value that works best across schedulers. +Because of that, the default value of `num_inference_steps` in the `__call__` method of the pipeline is set to `None` (see the API reference). +Unless set explicitly, its value will be taken from the checkpoint configuration `model_index.json`. +This is done to ensure high-quality predictions when calling the pipeline with just the `image` argument. + + + +See also Marigold [usage examples](marigold_usage). + +## MarigoldDepthPipeline +[[autodoc]] MarigoldDepthPipeline + - all + - __call__ + +## MarigoldNormalsPipeline +[[autodoc]] MarigoldNormalsPipeline + - all + - __call__ + +## MarigoldDepthOutput +[[autodoc]] pipelines.marigold.pipeline_marigold_depth.MarigoldDepthOutput + +## MarigoldNormalsOutput +[[autodoc]] pipelines.marigold.pipeline_marigold_normals.MarigoldNormalsOutput \ No newline at end of file diff --git a/docs/source/en/using-diffusers/marigold_usage.md b/docs/source/en/using-diffusers/marigold_usage.md new file mode 100644 index 000000000000..9c66d9d394aa --- /dev/null +++ b/docs/source/en/using-diffusers/marigold_usage.md @@ -0,0 +1,399 @@ + + +# Marigold Pipelines for Computer Vision Tasks + +[Marigold](marigold) is a novel diffusion-based dense prediction approach, and a set of pipelines for various computer vision tasks, such as monocular depth estimation. + +This guide will show you how to use Marigold to obtain fast and high-quality predictions for images and videos. + +Each pipeline supports one Computer Vision task, which takes an input RGB image as input and produces a *prediction* of the modality of interest, such as a depth map of the input image. +Currently, the following tasks are implemented: + +| Pipeline | Predicted Modalities | Demos | +|---------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------:| +| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-lcm), [Slow Original Demo (DDIM)](https://huggingface.co/spaces/prs-eth/marigold) | +| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-normals-lcm) | + +The original checkpoints can be found under the [PRS-ETH](https://huggingface.co/prs-eth/) Hugging Face organization. +These checkpoints are meant to work with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold). +The original code can also be used to train new checkpoints. + +| Checkpoint | Modality | Comment | +|-----------------------------------------------------------------------------------------------|----------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [prs-eth/marigold-v1-0](https://huggingface.co/prs-eth/marigold-v1-0) | Depth | The first Marigold Depth checkpoint, which predicts *affine-invariant depth* maps. The performance of this checkpoint in benchmarks was studied in the original [paper](https://huggingface.co/papers/2312.02145). Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. Affine-invariant depth prediction has a range of values in each pixel between 0 (near plane) and 1 (far plane); both planes are chosen by the model as part of the inference process. See the `MarigoldImageProcessor` reference for visualization utilities. | +| [prs-eth/marigold-lcm-v1-0](https://huggingface.co/prs-eth/marigold-lcm-v1-0) | Depth | The fast Marigold Depth checkpoint, fine-tuned from `prs-eth/marigold-v1-0`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. | +| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | A preview checkpoint for the Marigold Normals pipeline. Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. The surface normals predictions are unit-length 3D vectors with values in the range from -1 to 1. *This checkpoint will be phased out after the release of `v1-0` version.* | +| [prs-eth/marigold-normals-lcm-v0-1](https://huggingface.co/prs-eth/marigold-normals-lcm-v0-1) | Normals | The fast Marigold Normals checkpoint, fine-tuned from `prs-eth/marigold-normals-v0-1`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. *This checkpoint will be phased out after the release of `v1-0` version.* | +The examples below are mostly given for depth prediction, but they can be universally applied with other supported modalities. +We showcase the predictions using the same input image of Albert Einstein generated by Midjourney. +This makes it easier to compare visualizations of the predictions across various modalities and checkpoints. + +
+
+ +
+ Example input image for all Marigold pipelines +
+
+
+ +### Depth Prediction Quick Start + +To get the first depth prediction, load `prs-eth/marigold-depth-lcm-v1-0` checkpoint into `MarigoldDepthPipeline` pipeline, put the image through the pipeline, and save the predictions: + +```python +import diffusers +import torch + +pipe = diffusers.MarigoldDepthPipeline.from_pretrained( + "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 +).to("cuda") + +image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +depth = pipe(image) + +vis = pipe.image_processor.visualize_depth(depth.prediction) +vis[0].save("einstein_depth.png") + +depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction) +depth_16bit[0].save("einstein_depth_16bit.png") +``` + +The visualization function for depth [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth`] applies one of [matplotlib's colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html) (`Spectral` by default) to map the predicted pixel values from a single-channel `[0, 1]` depth range into an RGB image. +With the `Spectral` colormap, pixels with near depth are painted red, and far pixels are assigned blue color. +The 16-bit PNG file stores the single channel values mapped linearly from the `[0, 1]` range into `[0, 65535]`. +Below are the raw and the visualized predictions; as can be seen, dark areas (mustache) are easier to distinguish in the visualization: + +
+
+ +
+ Predicted depth (16-bit PNG) +
+
+
+ +
+ Predicted depth visualization (Spectral) +
+
+
+ +### Surface Normals Prediction Quick Start + +Load `prs-eth/marigold-normals-lcm-v0-1` checkpoint into `MarigoldNormalsPipeline` pipeline, put the image through the pipeline, and save the predictions: + +```python +import diffusers +import torch + +pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( + "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 +).to("cuda") + +image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +normals = pipe(image) + +vis = pipe.image_processor.visualize_normals(normals.prediction) +vis[0].save("einstein_normals.png") +``` + +The visualization function for normals [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals`] maps the three-dimensional prediction with pixel values in the range `[-1, 1]` into an RGB image. +The visualization function supports flipping surface normals axes to make the visualization compatible with other choices of the frame of reference. +Conceptually, each pixel is painted according to the surface normal vector in the frame of reference, where `X` axis points right, `Y` axis points up, and `Z` axis points at the viewer. +Below is the visualized prediction: + +
+
+ +
+ Predicted surface normals visualization +
+
+
+ +In this example, the nose tip almost certainly has a point on the surface, in which the surface normal vector points straight at the viewer, meaning that its coordinates are `[0, 0, 1]`. +This vector maps to the RGB `[128, 128, 255]`, which corresponds to the violet-blue color. +Similarly, a surface normal on the cheek in the right part of the image has a large `X` component, which increases the red hue. +Points on the shoulders pointing up with a large `Y` promote green color. + +### Speeding up inference + +The above quick start snippets are already optimized for speed: they load the LCM checkpoint, use the `fp16` variant of weights and computation, and perform just one denoising diffusion step. +The `pipe(image)` call completes in 280ms on RTX 3090 GPU. +Internally, the input image is encoded with the Stable Diffusion VAE encoder, then the U-Net performs one denoising step, and finally, the prediction latent is decoded with the VAE decoder into pixel space. +In this case, two out of three module calls are dedicated to converting between pixel and latent space of LDM. +Because Marigold's latent space is compatible with the base Stable Diffusion, it is possible to speed up the pipeline call by more than 3x (85ms on RTX 3090) by using a [lightweight replacement of the SD VAE](autoencoder_tiny): + +```diff + import diffusers + import torch + + pipe = diffusers.MarigoldDepthPipeline.from_pretrained( + "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 + ).to("cuda") + ++ pipe.vae = diffusers.AutoencoderTiny.from_pretrained( ++ "madebyollin/taesd", torch_dtype=torch.float16 ++ ).cuda() + + image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") + depth = pipe(image) +``` + +As suggested in [Optimizations](torch2.0), adding `torch.compile` may squeeze extra performance depending on the target hardware: + +```diff + import diffusers + import torch + + pipe = diffusers.MarigoldDepthPipeline.from_pretrained( + "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 + ).to("cuda") + ++ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") + depth = pipe(image) +``` + +## Qualitative Comparison with Depth Anything + +With the above speed optimizations, Marigold delivers predictions with more details and faster than [Depth Anything](https://huggingface.co/docs/transformers/main/en/model_doc/depth_anything) with the largest checkpoint [LiheYoung/depth-anything-large-hf](https://huggingface.co/LiheYoung/depth-anything-large-hf): + +
+
+ +
+ Marigold LCM fp16 with Tiny AutoEncoder +
+
+
+ +
+ Depth Anything Large +
+
+
+ +## Maximizing Precision and Ensembling + +Marigold pipelines have a built-in ensembling mechanism combining multiple predictions from different random latents. +This is a brute-force way of improving the precision of predictions, capitalizing on the generative nature of diffusion. +The ensembling path is activated automatically when the `ensemble_size` argument is set greater than `1`. +When aiming for maximum precision, it makes sense to adjust `num_inference_steps` simultaneously with `ensemble_size`. +The recommended values vary across checkpoints but primarily depend on the scheduler type. +The effect of ensembling is particularly well-seen with surface normals: + +```python +import diffusers + +model_path = "prs-eth/marigold-normals-v1-0" + +model_paper_kwargs = { + diffusers.schedulers.DDIMScheduler: { + "num_inference_steps": 10, + "ensemble_size": 10, + }, + diffusers.schedulers.LCMScheduler: { + "num_inference_steps": 4, + "ensemble_size": 5, + }, +} + +image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") + +pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(model_path).to("cuda") +pipe_kwargs = model_paper_kwargs[type(pipe.scheduler)] + +depth = pipe(image, **pipe_kwargs) + +vis = pipe.image_processor.visualize_normals(depth.prediction) +vis[0].save("einstein_normals.png") +``` + +
+
+ +
+ Surface normals, no ensembling +
+
+
+ +
+ Surface normals, with ensembling +
+
+
+ +As can be seen, all areas with fine-grained structurers, such as hair, got more conservative and on average more correct predictions. +Such a result is more suitable for precision-sensitive downstream tasks, such as 3D reconstruction. + +## Quantitative Evaluation + +To evaluate Marigold quantitatively in standard leaderboards and benchmarks (such as NYU, KITTI, and other datasets), follow the evaluation protocol outlined in the paper: load the full precision fp32 model and use appropriate values for `num_inference_steps` and `ensemble_size`. +Optionally seed randomness to ensure reproducibility. Maximizing `batch_size` will deliver maximum device utilization. + +```python +import diffusers +import torch + +device = "cuda" +seed = 2024 +model_path = "prs-eth/marigold-v1-0" + +model_paper_kwargs = { + diffusers.schedulers.DDIMScheduler: { + "num_inference_steps": 50, + "ensemble_size": 10, + }, + diffusers.schedulers.LCMScheduler: { + "num_inference_steps": 4, + "ensemble_size": 10, + }, +} + +image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") + +generator = torch.Generator(device=device).manual_seed(seed) +pipe = diffusers.MarigoldDepthPipeline.from_pretrained(model_path).to(device) +pipe_kwargs = model_paper_kwargs[type(pipe.scheduler)] + +depth = pipe(image, generator=generator, **pipe_kwargs) + +# evaluate metrics +``` + +## Using Predictive Uncertainty + +The ensembling mechanism built into Marigold pipelines combines multiple predictions obtained from different random latents. +As a side effect, it can be used to quantify epistemic (model) uncertainty; simply specify `ensemble_size` greater than 1 and set `output_uncertainty=True`. +The resulting uncertainty will be available in the `uncertainty` field of the output. +It can be visualized as follows: + +```python +import diffusers +import torch + +pipe = diffusers.MarigoldDepthPipeline.from_pretrained( + "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 +).to("cuda") + +image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +depth = pipe( + image, + ensemble_size=10, # any number greater than 1; higher values yield higher precision + output_uncertainty=True, +) + +uncertainty = pipe.image_processor.visualize_uncertainty(depth.uncertainty) +uncertainty[0].save("einstein_depth_uncertainty.png") +``` + +
+
+ +
+ Depth uncertainty +
+
+
+ +
+ Surface normals uncertainty +
+
+
+ +The interpretation of uncertainty is easy: higher values (white) correspond to pixels, where the model struggles to make consistent predictions. +Evidently, the depth model is the least confident around edges with discontinuity, where the object depth changes drastically. +The surface normals model is the least confident in fine-grained structures, such as hair, and dark areas, such as the collar. + +## Frame-by-frame Video Processing with Temporal Consistency + +Due to Marigold's generative nature, each prediction is unique and defined by the random noise sampled for the latent initialization. +This becomes an obvious drawback compared to traditional end-to-end dense regression networks, as exemplified in the following videos: + +
+
+ +
Input video
+
+
+ +
Marigold Depth applied to input video frames independently
+
+
+ +To address this issue, it is possible to pass `latents` argument to the pipelines, which defines the starting point of diffusion. +Empirically, we found that a convex combination of the very same starting point noise latent and the latent corresponding to the previous frame prediction give sufficiently smooth results, as implemented in the snippet below: + +```python +import imageio +from PIL import Image +from tqdm import tqdm +import diffusers +import torch + +device = "cuda" +path_in = "obama.mp4" +path_out = "obama_depth.gif" + +pipe = diffusers.MarigoldDepthPipeline.from_pretrained( + "prs-eth/marigold-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 +).to(device) +pipe.vae = diffusers.AutoencoderTiny.from_pretrained( + "madebyollin/taesd", torch_dtype=torch.float16 +).to(device) +pipe.set_progress_bar_config(disable=True) + +with imageio.get_reader(path_in) as reader: + size = reader.get_meta_data()['size'] + last_frame_latent = None + latent_common = torch.randn( + (1, 4, 768 * size[1] // (8 * max(size)), 768 * size[0] // (8 * max(size))) + ).to(device=device, dtype=torch.float16) + + out = [] + for frame_id, frame in tqdm(enumerate(reader), desc="Processing Video"): + frame = Image.fromarray(frame) + latents = latent_common + if last_frame_latent is not None: + latents = 0.9 * latents + 0.1 * last_frame_latent + + depth = pipe( + frame, match_input_resolution=False, latents=latents, output_latent=True, + ) + last_frame_latent = depth.latent + out.append(pipe.image_processor.visualize_depth(depth.prediction)[0]) + + diffusers.utils.export_to_gif(out, path_out, fps=reader.get_meta_data()['fps']) +``` + +Here, the diffusion process starts from the given computed latent. +The pipeline sets `output_latent=True` to access `out.latent` and computes its contribution to the next frame's latent initialization. +The result is much more stable now: + +
+
+ +
Marigold Depth applied to input video frames independently
+
+
+ +
Marigold Depth with forced latents initialization
+
+
+ +Hopefully, you will find Marigold useful for solving your downstream tasks, be it a part of a more broad generative workflow, or a broader perception task, such as 3D reconstruction. \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 66c98804eadc..c255c0a39afb 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -259,6 +259,8 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "MarigoldDepthPipeline", + "MarigoldNormalsPipeline", "MusicLDMPipeline", "PaintByExamplePipeline", "PIAPipeline", @@ -637,6 +639,8 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + MarigoldDepthPipeline, + MarigoldNormalsPipeline, MusicLDMPipeline, PaintByExamplePipeline, PIAPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c2dd7ac0d551..36031e546a77 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -24,6 +24,7 @@ "deprecated": [], "latent_diffusion": [], "ledits_pp": [], + "marigold": [], "stable_diffusion": [], "stable_diffusion_xl": [], } @@ -185,6 +186,12 @@ "LEditsPPPipelineStableDiffusionXL", ] ) + _import_structure["marigold"].extend( + [ + "MarigoldDepthPipeline", + "MarigoldNormalsPipeline", + ] + ) _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] @@ -448,6 +455,10 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) + from .marigold import ( + MarigoldDepthPipeline, + MarigoldNormalsPipeline, + ) from .musicldm import MusicLDMPipeline from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline diff --git a/src/diffusers/pipelines/marigold/__init__.py b/src/diffusers/pipelines/marigold/__init__.py new file mode 100644 index 000000000000..b5ae03adfc11 --- /dev/null +++ b/src/diffusers/pipelines/marigold/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["marigold_image_processing"] = ["MarigoldImageProcessor"] + _import_structure["pipeline_marigold_depth"] = ["MarigoldDepthOutput", "MarigoldDepthPipeline"] + _import_structure["pipeline_marigold_normals"] = ["MarigoldNormalsOutput", "MarigoldNormalsPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .marigold_image_processing import MarigoldImageProcessor + from .pipeline_marigold_depth import MarigoldDepthOutput, MarigoldDepthPipeline + from .pipeline_marigold_normals import MarigoldNormalsOutput, MarigoldNormalsPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/marigold/marigold_image_processing.py b/src/diffusers/pipelines/marigold/marigold_image_processing.py new file mode 100644 index 000000000000..6768517cecfc --- /dev/null +++ b/src/diffusers/pipelines/marigold/marigold_image_processing.py @@ -0,0 +1,561 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from PIL import Image + +from ... import ConfigMixin +from ...configuration_utils import register_to_config +from ...image_processor import PipelineImageInput +from ...utils import CONFIG_NAME, logging +from ...utils.import_utils import is_matplotlib_available + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MarigoldImageProcessor(ConfigMixin): + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + vae_scale_factor: int = 8, + do_normalize: bool = True, + do_range_check: bool = True, + ): + super().__init__() + + @staticmethod + def expand_tensor_or_array(images: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: + """ + Expand a tensor or array to a specified number of images. + """ + if isinstance(images, np.ndarray): + if images.ndim == 2: # [H,W] -> [1,H,W,1] + images = images[None, ..., None] + if images.ndim == 3: # [H,W,C] -> [1,H,W,C] + images = images[None] + elif isinstance(images, torch.Tensor): + if images.ndim == 2: # [H,W] -> [1,1,H,W] + images = images[None, None] + elif images.ndim == 3: # [1,H,W] -> [1,1,H,W] + images = images[None] + else: + raise ValueError(f"Unexpected input type: {type(images)}") + return images + + @staticmethod + def pt_to_numpy(images: torch.Tensor) -> np.ndarray: + """ + Convert a PyTorch tensor to a NumPy image. + """ + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + return images + + @staticmethod + def numpy_to_pt(images: np.ndarray) -> torch.Tensor: + """ + Convert a NumPy image to a PyTorch tensor. + """ + if np.issubdtype(images.dtype, np.integer) and not np.issubdtype(images.dtype, np.unsignedinteger): + raise ValueError(f"Input image dtype={images.dtype} cannot be a signed integer.") + if np.issubdtype(images.dtype, np.complexfloating): + raise ValueError(f"Input image dtype={images.dtype} cannot be complex.") + if np.issubdtype(images.dtype, bool): + raise ValueError(f"Input image dtype={images.dtype} cannot be boolean.") + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + @staticmethod + def resize_antialias( + image: torch.Tensor, size: Tuple[int, int], mode: str, is_aa: Optional[bool] = None + ) -> torch.Tensor: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + antialias = is_aa and mode in ("bilinear", "bicubic") + image = F.interpolate(image, size, mode=mode, antialias=antialias) + + return image + + @staticmethod + def resize_to_max_edge(image: torch.Tensor, max_edge_sz: int, mode: str) -> torch.Tensor: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + h, w = image.shape[-2:] + max_orig = max(h, w) + new_h = h * max_edge_sz // max_orig + new_w = w * max_edge_sz // max_orig + + if new_h == 0 or new_w == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{w} x {h}]") + + image = MarigoldImageProcessor.resize_antialias(image, (new_h, new_w), mode, is_aa=True) + + return image + + @staticmethod + def pad_image(image: torch.Tensor, align: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + h, w = image.shape[-2:] + ph, pw = -h % align, -w % align + + image = F.pad(image, (0, pw, 0, ph), mode="replicate") + + return image, (ph, pw) + + @staticmethod + def unpad_image(image: torch.Tensor, padding: Tuple[int, int]) -> torch.Tensor: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + ph, pw = padding + uh = None if ph == 0 else -ph + uw = None if pw == 0 else -pw + + image = image[:, :, :uh, :uw] + + return image + + @staticmethod + def load_image_canonical( + image: Union[torch.Tensor, np.ndarray, Image.Image], + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ) -> Tuple[torch.Tensor, int]: + if isinstance(image, Image.Image): + image = np.array(image) + + image_dtype_max = None + if isinstance(image, (np.ndarray, torch.Tensor)): + image = MarigoldImageProcessor.expand_tensor_or_array(image) + if image.ndim != 4: + raise ValueError("Input image is not 2-, 3-, or 4-dimensional.") + if isinstance(image, np.ndarray): + if np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.unsignedinteger): + raise ValueError(f"Input image dtype={image.dtype} cannot be a signed integer.") + if np.issubdtype(image.dtype, np.complexfloating): + raise ValueError(f"Input image dtype={image.dtype} cannot be complex.") + if np.issubdtype(image.dtype, bool): + raise ValueError(f"Input image dtype={image.dtype} cannot be boolean.") + if np.issubdtype(image.dtype, np.unsignedinteger): + image_dtype_max = np.iinfo(image.dtype).max + image = image.astype(np.float32) # because torch does not have unsigned dtypes beyond torch.uint8 + image = MarigoldImageProcessor.numpy_to_pt(image) + + if torch.is_tensor(image) and not torch.is_floating_point(image) and image_dtype_max is None: + if image.dtype != torch.uint8: + raise ValueError(f"Image dtype={image.dtype} is not supported.") + image_dtype_max = 255 + + if not torch.is_tensor(image): + raise ValueError(f"Input type unsupported: {type(image)}.") + + if image.shape[1] == 1: + image = image.repeat(1, 3, 1, 1) # [N,1,H,W] -> [N,3,H,W] + if image.shape[1] != 3: + raise ValueError(f"Input image is not 1- or 3-channel: {image.shape}.") + + image = image.to(device=device, dtype=dtype) + + if image_dtype_max is not None: + image = image / image_dtype_max + + return image + + @staticmethod + def check_image_values_range(image: torch.Tensor) -> None: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.min().item() < 0.0 or image.max().item() > 1.0: + raise ValueError("Input image data is partially outside of the [0,1] range.") + + def preprocess( + self, + image: PipelineImageInput, + processing_resolution: Optional[int] = None, + resample_method_input: str = "bilinear", + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ): + if isinstance(image, list): + images = None + for i, img in enumerate(image): + img = self.load_image_canonical(img, device, dtype) # [N,3,H,W] + if images is None: + images = img + else: + if images.shape[2:] != img.shape[2:]: + raise ValueError( + f"Input image[{i}] has incompatible dimensions {img.shape[2:]} with the previous images " + f"{images.shape[2:]}" + ) + images = torch.cat((images, img), dim=0) + image = images + del images + else: + image = self.load_image_canonical(image, device, dtype) # [N,3,H,W] + + original_resolution = image.shape[2:] + + if self.config.do_range_check: + self.check_image_values_range(image) + + if self.config.do_normalize: + image = image * 2.0 - 1.0 + + if processing_resolution is not None and processing_resolution > 0: + image = self.resize_to_max_edge(image, processing_resolution, resample_method_input) # [N,3,PH,PW] + + image, padding = self.pad_image(image, self.config.vae_scale_factor) # [N,3,PPH,PPW] + + return image, padding, original_resolution + + @staticmethod + def colormap( + image: Union[np.ndarray, torch.Tensor], + cmap: str = "Spectral", + bytes: bool = False, + _force_method: Optional[str] = None, + ) -> Union[np.ndarray, torch.Tensor]: + """ + Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the + behavior of matplotlib.colormaps, but allows the user to use the most discriminative color map "Spectral" + without having to install or import matplotlib. For all other cases, the function will attempt to use the + native implementation. + + Args: + image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor. + cmap: Colormap name. + bytes: Whether to return the output as uint8 or floating point image. + _force_method: + Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom + implementation of the "Spectral" color map (`"custom"`), or rely on autodetection (`None`, default). + + Returns: + An RGB-colorized tensor corresponding to the input image. + """ + if not (torch.is_tensor(image) or isinstance(image, np.ndarray)): + raise ValueError("Argument must be a numpy array or torch tensor.") + if _force_method not in (None, "matplotlib", "custom"): + raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.") + + def method_matplotlib(image, cmap, bytes=False): + if is_matplotlib_available(): + import matplotlib + else: + return None + + arg_is_pt, device = torch.is_tensor(image), None + if arg_is_pt: + image, device = image.cpu().numpy(), image.device + + if cmap not in matplotlib.colormaps: + raise ValueError( + f"Unexpected color map {cmap}; available options are: {', '.join(list(matplotlib.colormaps.keys()))}" + ) + + cmap = matplotlib.colormaps[cmap] + out = cmap(image, bytes=bytes) # [?,4] + out = out[..., :3] # [?,3] + + if arg_is_pt: + out = torch.tensor(out, device=device) + + return out + + def method_custom(image, cmap, bytes=False): + arg_is_np = isinstance(image, np.ndarray) + if arg_is_np: + image = torch.tensor(image) + if image.dtype == torch.uint8: + image = image.float() / 255 + else: + image = image.float() + + if cmap != "Spectral": + raise ValueError("Only 'Spectral' color map is available without installing matplotlib.") + + _Spectral_data = ( # Taken from matplotlib/_cm.py + (0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0] + (0.83529411764705885, 0.24313725490196078, 0.30980392156862746), + (0.95686274509803926, 0.42745098039215684, 0.2627450980392157), + (0.99215686274509807, 0.68235294117647061, 0.38039215686274508), + (0.99607843137254903, 0.8784313725490196, 0.54509803921568623), + (1.0, 1.0, 0.74901960784313726), + (0.90196078431372551, 0.96078431372549022, 0.59607843137254901), + (0.6705882352941176, 0.8666666666666667, 0.64313725490196083), + (0.4, 0.76078431372549016, 0.6470588235294118), + (0.19607843137254902, 0.53333333333333333, 0.74117647058823533), + (0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1] + ) + + cmap = torch.tensor(_Spectral_data, dtype=torch.float, device=image.device) # [K,3] + K = cmap.shape[0] + + pos = image.clamp(min=0, max=1) * (K - 1) + left = pos.long() + right = (left + 1).clamp(max=K - 1) + + d = (pos - left.float()).unsqueeze(-1) + left_colors = cmap[left] + right_colors = cmap[right] + + out = (1 - d) * left_colors + d * right_colors + + if bytes: + out = (out * 255).to(torch.uint8) + + if arg_is_np: + out = out.numpy() + + return out + + if _force_method is None and torch.is_tensor(image) and cmap == "Spectral": + return method_custom(image, cmap, bytes) + + out = None + if _force_method != "custom": + out = method_matplotlib(image, cmap, bytes) + + if _force_method == "matplotlib" and out is None: + raise ImportError("Make sure to install matplotlib if you want to use a color map other than 'Spectral'.") + + if out is None: + out = method_custom(image, cmap, bytes) + + return out + + @staticmethod + def visualize_depth( + depth: Union[ + PIL.Image.Image, + np.ndarray, + torch.Tensor, + List[PIL.Image.Image], + List[np.ndarray], + List[torch.Tensor], + ], + val_min: float = 0.0, + val_max: float = 1.0, + color_map: str = "Spectral", + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + """ + Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`. + + Args: + depth (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], + List[torch.Tensor]]`): Depth maps. + val_min (`float`, *optional*, defaults to `0.0`): Minimum value of the visualized depth range. + val_max (`float`, *optional*, defaults to `1.0`): Maximum value of the visualized depth range. + color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel + depth prediction into colored representation. + + Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization. + """ + if val_max <= val_min: + raise ValueError(f"Invalid values range: [{val_min}, {val_max}].") + + def visualize_depth_one(img, idx=None): + prefix = "Depth" + (f"[{idx}]" if idx else "") + if isinstance(img, PIL.Image.Image): + if img.mode != "I;16": + raise ValueError(f"{prefix}: invalid PIL mode={img.mode}.") + img = np.array(img).astype(np.float32) / (2**16 - 1) + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim != 2: + raise ValueError(f"{prefix}: unexpected shape={img.shape}.") + if isinstance(img, np.ndarray): + img = torch.from_numpy(img) + if not torch.is_floating_point(img): + raise ValueError(f"{prefix}: unexected dtype={img.dtype}.") + else: + raise ValueError(f"{prefix}: unexpected type={type(img)}.") + if val_min != 0.0 or val_max != 1.0: + img = (img - val_min) / (val_max - val_min) + img = MarigoldImageProcessor.colormap(img, cmap=color_map, bytes=True) # [H,W,3] + img = PIL.Image.fromarray(img.cpu().numpy()) + return img + + if depth is None or isinstance(depth, list) and any(o is None for o in depth): + raise ValueError("Input depth is `None`") + if isinstance(depth, (np.ndarray, torch.Tensor)): + depth = MarigoldImageProcessor.expand_tensor_or_array(depth) + if isinstance(depth, np.ndarray): + depth = MarigoldImageProcessor.numpy_to_pt(depth) # [N,H,W,1] -> [N,1,H,W] + if not (depth.ndim == 4 and depth.shape[1] == 1): # [N,1,H,W] + raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].") + return [visualize_depth_one(img[0], idx) for idx, img in enumerate(depth)] + elif isinstance(depth, list): + return [visualize_depth_one(img, idx) for idx, img in enumerate(depth)] + else: + raise ValueError(f"Unexpected input type: {type(depth)}") + + @staticmethod + def export_depth_to_16bit_png( + depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]], + val_min: float = 0.0, + val_max: float = 1.0, + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + def export_depth_to_16bit_png_one(img, idx=None): + prefix = "Depth" + (f"[{idx}]" if idx else "") + if not isinstance(img, np.ndarray) and not torch.is_tensor(img): + raise ValueError(f"{prefix}: unexpected type={type(img)}.") + if img.ndim != 2: + raise ValueError(f"{prefix}: unexpected shape={img.shape}.") + if torch.is_tensor(img): + img = img.cpu().numpy() + if not np.issubdtype(img.dtype, np.floating): + raise ValueError(f"{prefix}: unexected dtype={img.dtype}.") + if val_min != 0.0 or val_max != 1.0: + img = (img - val_min) / (val_max - val_min) + img = (img * (2**16 - 1)).astype(np.uint16) + img = PIL.Image.fromarray(img, mode="I;16") + return img + + if depth is None or isinstance(depth, list) and any(o is None for o in depth): + raise ValueError("Input depth is `None`") + if isinstance(depth, (np.ndarray, torch.Tensor)): + depth = MarigoldImageProcessor.expand_tensor_or_array(depth) + if isinstance(depth, np.ndarray): + depth = MarigoldImageProcessor.numpy_to_pt(depth) # [N,H,W,1] -> [N,1,H,W] + if not (depth.ndim == 4 and depth.shape[1] == 1): + raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].") + return [export_depth_to_16bit_png_one(img[0], idx) for idx, img in enumerate(depth)] + elif isinstance(depth, list): + return [export_depth_to_16bit_png_one(img, idx) for idx, img in enumerate(depth)] + else: + raise ValueError(f"Unexpected input type: {type(depth)}") + + @staticmethod + def visualize_normals( + normals: Union[ + np.ndarray, + torch.Tensor, + List[np.ndarray], + List[torch.Tensor], + ], + flip_x: bool = False, + flip_y: bool = False, + flip_z: bool = False, + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + """ + Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`. + + Args: + normals (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`): + Surface normals. + flip_x (`bool`, *optional*, defaults to `False`): Flips the X axis of the normals frame of reference. + Default direction is right. + flip_y (`bool`, *optional*, defaults to `False`): Flips the Y axis of the normals frame of reference. + Default direction is top. + flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference. + Default direction is facing the observer. + + Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization. + """ + flip_vec = None + if any((flip_x, flip_y, flip_z)): + flip_vec = torch.tensor( + [ + (-1) ** flip_x, + (-1) ** flip_y, + (-1) ** flip_z, + ], + dtype=torch.float32, + ) + + def visualize_normals_one(img, idx=None): + img = img.permute(1, 2, 0) + if flip_vec is not None: + img *= flip_vec.to(img.device) + img = (img + 1.0) * 0.5 + img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy() + img = PIL.Image.fromarray(img) + return img + + if normals is None or isinstance(normals, list) and any(o is None for o in normals): + raise ValueError("Input normals is `None`") + if isinstance(normals, (np.ndarray, torch.Tensor)): + normals = MarigoldImageProcessor.expand_tensor_or_array(normals) + if isinstance(normals, np.ndarray): + normals = MarigoldImageProcessor.numpy_to_pt(normals) # [N,3,H,W] + if not (normals.ndim == 4 and normals.shape[1] == 3): + raise ValueError(f"Unexpected input shape={normals.shape}, expecting [N,3,H,W].") + return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] + elif isinstance(normals, list): + return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] + else: + raise ValueError(f"Unexpected input type: {type(normals)}") + + @staticmethod + def visualize_uncertainty( + uncertainty: Union[ + np.ndarray, + torch.Tensor, + List[np.ndarray], + List[torch.Tensor], + ], + saturation_percentile=95, + ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]: + """ + Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`. + + Args: + uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`): + Uncertainty maps. + saturation_percentile (`int`, *optional*, defaults to `95`): + Specifies the percentile uncertainty value visualized with maximum intensity. + + Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization. + """ + + def visualize_uncertainty_one(img, idx=None): + prefix = "Uncertainty" + (f"[{idx}]" if idx else "") + if img.min() < 0: + raise ValueError(f"{prefix}: unexected data range, min={img.min()}.") + img = img.squeeze(0).cpu().numpy() + saturation_value = np.percentile(img, saturation_percentile) + img = np.clip(img * 255 / saturation_value, 0, 255) + img = img.astype(np.uint8) + img = PIL.Image.fromarray(img) + return img + + if uncertainty is None or isinstance(uncertainty, list) and any(o is None for o in uncertainty): + raise ValueError("Input uncertainty is `None`") + if isinstance(uncertainty, (np.ndarray, torch.Tensor)): + uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty) + if isinstance(uncertainty, np.ndarray): + uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,1,H,W] + if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1): + raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].") + return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] + elif isinstance(uncertainty, list): + return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] + else: + raise ValueError(f"Unexpected input type: {type(uncertainty)}") diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py new file mode 100644 index 000000000000..a602ba611ea5 --- /dev/null +++ b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py @@ -0,0 +1,813 @@ +# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput +from ...models import ( + AutoencoderKL, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + LCMScheduler, +) +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...utils.import_utils import is_scipy_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .marigold_image_processing import MarigoldImageProcessor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ +Examples: +```py +>>> import diffusers +>>> import torch + +>>> pipe = diffusers.MarigoldDepthPipeline.from_pretrained( +... "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 +... ).to("cuda") + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> depth = pipe(image) + +>>> vis = pipe.image_processor.visualize_depth(depth.prediction) +>>> vis[0].save("einstein_depth.png") + +>>> depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction) +>>> depth_16bit[0].save("einstein_depth_16bit.png") +``` +""" + + +@dataclass +class MarigoldDepthOutput(BaseOutput): + """ + Output class for Marigold monocular depth prediction pipeline. + + Args: + prediction (`np.ndarray`, `torch.Tensor`): + Predicted depth maps with values in the range [0, 1]. The shape is always $numimages \times 1 \times height + \times width$, regardless of whether the images were passed as a 4D array or a list. + uncertainty (`None`, `np.ndarray`, `torch.Tensor`): + Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages + \times 1 \times height \times width$. + latent (`None`, `torch.Tensor`): + Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. + The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. + """ + + prediction: Union[np.ndarray, torch.Tensor] + uncertainty: Union[None, np.ndarray, torch.Tensor] + latent: Union[None, torch.Tensor] + + +class MarigoldDepthPipeline(DiffusionPipeline): + """ + Pipeline for monocular depth estimation using the Marigold method: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the depth latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent + representations. + scheduler (`DDIMScheduler` or `LCMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + prediction_type (`str`, *optional*): + Type of predictions made by the model. + scale_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in + the model config. When used together with the `shift_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + shift_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in + the model config. When used together with the `scale_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + supported_prediction_types = ("depth", "disparity") + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, LCMScheduler], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + prediction_type: Optional[str] = None, + scale_invariant: Optional[bool] = True, + shift_invariant: Optional[bool] = True, + default_denoising_steps: Optional[int] = None, + default_processing_resolution: Optional[int] = None, + ): + super().__init__() + + if prediction_type not in self.supported_prediction_types: + logger.warning( + f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " + f"{self.supported_prediction_types}." + ) + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + self.register_to_config( + prediction_type=prediction_type, + scale_invariant=scale_invariant, + shift_invariant=shift_invariant, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.scale_invariant = scale_invariant + self.shift_invariant = shift_invariant + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + + self.empty_text_embedding = None + + self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + image: PipelineImageInput, + num_inference_steps: int, + ensemble_size: int, + processing_resolution: int, + resample_method_input: str, + resample_method_output: str, + batch_size: int, + ensembling_kwargs: Optional[Dict[str, Any]], + latents: Optional[torch.Tensor], + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + output_type: str, + output_uncertainty: bool, + ) -> int: + if num_inference_steps is None: + raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") + if num_inference_steps < 1: + raise ValueError("`num_inference_steps` must be positive.") + if ensemble_size < 1: + raise ValueError("`ensemble_size` must be positive.") + if ensemble_size == 2: + logger.warning( + "`ensemble_size` == 2 results are similar to no ensembling (1); " + "consider increasing the value to at least 3." + ) + if ensemble_size > 1 and (self.scale_invariant or self.shift_invariant) and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use ensembling.") + if ensemble_size == 1 and output_uncertainty: + raise ValueError( + "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " + "greater than 1." + ) + if processing_resolution is None: + raise ValueError( + "`processing_resolution` is not specified and could not be resolved from the model config." + ) + if processing_resolution < 0: + raise ValueError( + "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " + "downsampled processing." + ) + if processing_resolution % self.vae_scale_factor != 0: + raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") + if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_input` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_output` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if batch_size < 1: + raise ValueError("`batch_size` must be positive.") + if output_type not in ["pt", "np"]: + raise ValueError("`output_type` must be one of `pt` or `np`.") + if latents is not None and generator is not None: + raise ValueError("`latents` and `generator` cannot be used together.") + if ensembling_kwargs is not None: + if not isinstance(ensembling_kwargs, dict): + raise ValueError("`ensembling_kwargs` must be a dictionary.") + if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("mean", "median"): + raise ValueError("`ensembling_kwargs['reduction']` can be either `'mean'` or `'median'`.") + + # image checks + num_images = 0 + W, H = None, None + if not isinstance(image, list): + image = [image] + for i, img in enumerate(image): + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim not in (2, 3, 4): + raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") + H_i, W_i = img.shape[-2:] + N_i = 1 + if img.ndim == 4: + N_i = img.shape[0] + elif isinstance(img, Image.Image): + W_i, H_i = img.size + N_i = 1 + else: + raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") + if W is None: + W, H = W_i, H_i + elif (W, H) != (W_i, H_i): + raise ValueError( + f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" + ) + num_images += N_i + + # latents checks + if latents is not None: + if not torch.is_tensor(latents): + raise ValueError("`latents` must be a torch.Tensor.") + if latents.dim() != 4: + raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") + + if processing_resolution > 0: + max_orig = max(H, W) + new_H = H * processing_resolution // max_orig + new_W = W * processing_resolution // max_orig + if new_H == 0 or new_W == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") + W, H = new_W, new_H + w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor + h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor + shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) + + if latents.shape != shape_expected: + raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") + + # generator checks + if generator is not None: + if isinstance(generator, list): + if len(generator) != num_images * ensemble_size: + raise ValueError( + "The number of generators must match the total number of ensemble members for all input images." + ) + if not all(g.device.type == generator[0].device.type for g in generator): + raise ValueError("`generator` device placement is not consistent in the list.") + elif not isinstance(generator, torch.Generator): + raise ValueError(f"Unsupported generator type: {type(generator)}.") + + return num_images + + def progress_bar(self, iterable=None, total=None, desc=None, leave=True): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + progress_bar_config = dict(**self._progress_bar_config) + progress_bar_config["desc"] = progress_bar_config.get("desc", desc) + progress_bar_config["leave"] = progress_bar_config.get("leave", leave) + if iterable is not None: + return tqdm(iterable, **progress_bar_config) + elif total is not None: + return tqdm(total=total, **progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + num_inference_steps: Optional[int] = None, + ensemble_size: int = 1, + processing_resolution: Optional[int] = None, + match_input_resolution: bool = True, + resample_method_input: str = "bilinear", + resample_method_output: str = "bilinear", + batch_size: int = 1, + ensembling_kwargs: Optional[Dict[str, Any]] = None, + latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "np", + output_uncertainty: bool = False, + output_latent: bool = False, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), + `List[torch.Tensor]`: An input image or images used as an input for the depth estimation task. For + arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible + by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or + three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the + same width and height. + num_inference_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 + for Marigold-LCM models. + ensemble_size (`int`, defaults to `1`): + Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for + faster inference. + processing_resolution (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, matches the larger input image dimension. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. + match_input_resolution (`bool`, *optional*, defaults to `True`): + When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer + side of the output will equal to `processing_resolution`. + resample_method_input (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize input images to `processing_resolution`. The accepted values are: + `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + resample_method_output (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize output predictions to match the input resolution. The accepted values + are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + batch_size (`int`, *optional*, defaults to `1`): + Batch size; only matters when setting `ensemble_size` or passing a tensor of images. + ensembling_kwargs (`dict`, *optional*, defaults to `None`) + Extra dictionary with arguments for precise ensembling control. The following options are available: + - reduction (`str`, *optional*, defaults to `"median"`): Defines the ensembling function applied in + every pixel location, can be either `"median"` or `"mean"`. + - regularizer_strength (`float`, *optional*, defaults to `0.02`): Strength of the regularizer that + pulls the aligned predictions to the unit range from 0 to 1. + - max_iter (`int`, *optional*, defaults to `2`): Maximum number of the alignment solver steps. Refer to + `scipy.optimize.minimize` function, `options` argument. + - tol (`float`, *optional*, defaults to `1e-3`): Alignment solver tolerance. The solver stops when the + tolerance is reached. + - max_res (`int`, *optional*, defaults to `None`): Resolution at which the alignment is performed; + `None` matches the `processing_resolution`. + latents (`torch.Tensor`, or `List[torch.Tensor]`, *optional*, defaults to `None`): + Latent noise tensors to replace the random initialization. These can be taken from the previous + function call's output. + generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): + Random number generator object to ensure reproducibility. + output_type (`str`, *optional*, defaults to `"np"`): + Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted + values are: `"np"` (numpy array) or `"pt"` (torch tensor). + output_uncertainty (`bool`, *optional*, defaults to `False`): + When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that + the `ensemble_size` argument is set to a value above 2. + output_latent (`bool`, *optional*, defaults to `False`): + When enabled, the output's `latent` field contains the latent codes corresponding to the predictions + within the ensemble. These codes can be saved, modified, and used for subsequent calls with the + `latents` argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.marigold.MarigoldDepthOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.marigold.MarigoldDepthOutput`] is returned, otherwise a + `tuple` is returned where the first element is the prediction, the second element is the uncertainty + (or `None`), and the third is the latent (or `None`). + """ + + # 0. Resolving variables. + device = self._execution_device + dtype = self.dtype + + # Model-specific optimal default values leading to fast and reasonable results. + if num_inference_steps is None: + num_inference_steps = self.default_denoising_steps + if processing_resolution is None: + processing_resolution = self.default_processing_resolution + + # 1. Check inputs. + num_images = self.check_inputs( + image, + num_inference_steps, + ensemble_size, + processing_resolution, + resample_method_input, + resample_method_output, + batch_size, + ensembling_kwargs, + latents, + generator, + output_type, + output_uncertainty, + ) + + # 2. Prepare empty text conditioning. + # Model invocation: self.tokenizer, self.text_encoder. + if self.empty_text_embedding is None: + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] + + # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, + # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where + # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are + # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` + # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of + # operation and leads to the most reasonable results. Using the native image resolution or any other processing + # resolution can lead to loss of either fine details or global context in the output predictions. + image, padding, original_resolution = self.image_processor.preprocess( + image, processing_resolution, resample_method_input, device, dtype + ) # [N,3,PPH,PPW] + + # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` + # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. + # Latents of each such predictions across all input images and all ensemble members are represented in the + # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded + # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure + # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline + # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken + # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled + # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space + # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. + # Model invocation: self.vae.encoder. + image_latent, pred_latent = self.prepare_latents( + image, latents, generator, ensemble_size, batch_size + ) # [N*E,4,h,w], [N*E,4,h,w] + + del image + + batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat( + batch_size, 1, 1 + ) # [B,1024,2] + + # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. + # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and + # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by + # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded + # model. + # Model invocation: self.unet. + pred_latents = [] + + for i in self.progress_bar( + range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." + ): + batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] + batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w] + effective_batch_size = batch_image_latent.shape[0] + text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): + batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w] + noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] + batch_pred_latent = self.scheduler.step( + noise, t, batch_pred_latent, generator=generator + ).prev_sample # [B,4,h,w] + + pred_latents.append(batch_pred_latent) + + pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] + + del ( + pred_latents, + image_latent, + batch_empty_text_embedding, + batch_image_latent, + batch_pred_latent, + text, + batch_latent, + noise, + ) + + # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, + # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. + # Model invocation: self.vae.decoder. + prediction = torch.cat( + [ + self.decode_prediction(pred_latent[i : i + batch_size]) + for i in range(0, pred_latent.shape[0], batch_size) + ], + dim=0, + ) # [N*E,1,PPH,PPW] + + if not output_latent: + pred_latent = None + + # 7. Remove padding. The output shape is (PH, PW). + prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,1,PH,PW] + + # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N` + # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape + # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for + # each group independently, it stacks them respectively into batches of `N` almost final predictions and + # uncertainty maps. + uncertainty = None + if ensemble_size > 1: + prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:]) # [N,E,1,PH,PW] + prediction = [ + self.ensemble_depth( + prediction[i], + self.scale_invariant, + self.shift_invariant, + output_uncertainty, + **(ensembling_kwargs or {}), + ) + for i in range(num_images) + ] # [ [[1,1,PH,PW], [1,1,PH,PW]], ... ] + prediction, uncertainty = zip(*prediction) # [[1,1,PH,PW], ... ], [[1,1,PH,PW], ... ] + prediction = torch.cat(prediction, dim=0) # [N,1,PH,PW] + if output_uncertainty: + uncertainty = torch.cat(uncertainty, dim=0) # [N,1,PH,PW] + else: + uncertainty = None + + # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the + # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. + # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by + # setting the `resample_method_output` parameter (e.g., to `"nearest"`). + if match_input_resolution: + prediction = self.image_processor.resize_antialias( + prediction, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.resize_antialias( + uncertainty, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + + # 10. Prepare the final outputs. + if output_type == "np": + prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,1] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N,H,W,1] + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (prediction, uncertainty, pred_latent) + + return MarigoldDepthOutput( + prediction=prediction, + uncertainty=uncertainty, + latent=pred_latent, + ) + + def prepare_latents( + self, + image: torch.Tensor, + latents: Optional[torch.Tensor], + generator: Optional[torch.Generator], + ensemble_size: int, + batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + def retrieve_latents(encoder_output): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + image_latent = torch.cat( + [ + retrieve_latents(self.vae.encode(image[i : i + batch_size])) + for i in range(0, image.shape[0], batch_size) + ], + dim=0, + ) # [N,4,h,w] + image_latent = image_latent * self.vae.config.scaling_factor + image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] + + pred_latent = latents + if pred_latent is None: + pred_latent = randn_tensor( + image_latent.shape, + generator=generator, + device=image_latent.device, + dtype=image_latent.dtype, + ) # [N*E,4,h,w] + + return image_latent, pred_latent + + def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: + if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: + raise ValueError( + f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." + ) + + prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] + + prediction = prediction.mean(dim=1, keepdim=True) # [B,1,H,W] + prediction = torch.clip(prediction, -1.0, 1.0) # [B,1,H,W] + prediction = (prediction + 1.0) / 2.0 + + return prediction # [B,1,H,W] + + @staticmethod + def ensemble_depth( + depth: torch.Tensor, + scale_invariant: bool = True, + shift_invariant: bool = True, + output_uncertainty: bool = False, + reduction: str = "median", + regularizer_strength: float = 0.02, + max_iter: int = 2, + tol: float = 1e-3, + max_res: int = 1024, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Ensembles the depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the + number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for + depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The + alignment happens when the predictions have one or more degrees of freedom, that is when they are either + affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only + `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`) + alignment is skipped and only ensembling is performed. + + Args: + depth (`torch.Tensor`): + Input ensemble depth maps. + scale_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as scale-invariant. + shift_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as shift-invariant. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"median"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and + `"median"`. + regularizer_strength (`float`, *optional*, defaults to `0.02`): + Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1. + max_iter (`int`, *optional*, defaults to `2`): + Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options` + argument. + tol (`float`, *optional*, defaults to `1e-3`): + Alignment solver tolerance. The solver stops when the tolerance is reached. + max_res (`int`, *optional*, defaults to `1024`): + Resolution at which the alignment is performed; `None` matches the `processing_resolution`. + Returns: + A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape: + `(1, 1, H, W)`. + """ + if depth.dim() != 4 or depth.shape[1] != 1: + raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") + if reduction not in ("mean", "median"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + if not scale_invariant and shift_invariant: + raise ValueError("Pure shift-invariant ensembling is not supported.") + + def init_param(depth: torch.Tensor): + init_min = depth.reshape(ensemble_size, -1).min(dim=1).values + init_max = depth.reshape(ensemble_size, -1).max(dim=1).values + + if scale_invariant and shift_invariant: + init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) + init_t = -init_s * init_min + param = torch.cat((init_s, init_t)).cpu().numpy() + elif scale_invariant: + init_s = 1.0 / init_max.clamp(min=1e-6) + param = init_s.cpu().numpy() + else: + raise ValueError("Unrecognized alignment.") + + return param + + def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor: + if scale_invariant and shift_invariant: + s, t = np.split(param, 2) + s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1) + t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1) + out = depth * s + t + elif scale_invariant: + s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1) + out = depth * s + else: + raise ValueError("Unrecognized alignment.") + return out + + def ensemble( + depth_aligned: torch.Tensor, return_uncertainty: bool = False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + uncertainty = None + if reduction == "mean": + prediction = torch.mean(depth_aligned, dim=0, keepdim=True) + if return_uncertainty: + uncertainty = torch.std(depth_aligned, dim=0, keepdim=True) + elif reduction == "median": + prediction = torch.median(depth_aligned, dim=0, keepdim=True).values + if return_uncertainty: + uncertainty = torch.median(torch.abs(depth_aligned - prediction), dim=0, keepdim=True).values + else: + raise ValueError(f"Unrecognized reduction method: {reduction}.") + return prediction, uncertainty + + def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float: + cost = 0.0 + depth_aligned = align(depth, param) + + for i, j in torch.combinations(torch.arange(ensemble_size)): + diff = depth_aligned[i] - depth_aligned[j] + cost += (diff**2).mean().sqrt().item() + + if regularizer_strength > 0: + prediction, _ = ensemble(depth_aligned, return_uncertainty=False) + err_near = (0.0 - prediction.min()).abs().item() + err_far = (1.0 - prediction.max()).abs().item() + cost += (err_near + err_far) * regularizer_strength + + return cost + + def compute_param(depth: torch.Tensor): + import scipy + + depth_to_align = depth.to(torch.float32) + if max_res is not None and max(depth_to_align.shape[2:]) > max_res: + depth_to_align = MarigoldImageProcessor.resize_to_max_edge(depth_to_align, max_res, "nearest-exact") + + param = init_param(depth_to_align) + + res = scipy.optimize.minimize( + partial(cost_fn, depth=depth_to_align), + param, + method="BFGS", + tol=tol, + options={"maxiter": max_iter, "disp": False}, + ) + + return res.x + + requires_aligning = scale_invariant or shift_invariant + ensemble_size = depth.shape[0] + + if requires_aligning: + param = compute_param(depth) + depth = align(depth, param) + + depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) + + depth_max = depth.max() + if scale_invariant and shift_invariant: + depth_min = depth.min() + elif scale_invariant: + depth_min = 0 + else: + raise ValueError("Unrecognized alignment.") + depth_range = (depth_max - depth_min).clamp(min=1e-6) + depth = (depth - depth_min) / depth_range + if output_uncertainty: + uncertainty /= depth_range + + return depth, uncertainty # [1,1,H,W], [1,1,H,W] diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py new file mode 100644 index 000000000000..aa9ad36ffc35 --- /dev/null +++ b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py @@ -0,0 +1,690 @@ +# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput +from ...models import ( + AutoencoderKL, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + LCMScheduler, +) +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .marigold_image_processing import MarigoldImageProcessor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ +Examples: +```py +>>> import diffusers +>>> import torch + +>>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( +... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 +... ).to("cuda") + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> normals = pipe(image) + +>>> vis = pipe.image_processor.visualize_normals(normals.prediction) +>>> vis[0].save("einstein_normals.png") +``` +""" + + +@dataclass +class MarigoldNormalsOutput(BaseOutput): + """ + Output class for Marigold monocular normals prediction pipeline. + + Args: + prediction (`np.ndarray`, `torch.Tensor`): + Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height + \times width$, regardless of whether the images were passed as a 4D array or a list. + uncertainty (`None`, `np.ndarray`, `torch.Tensor`): + Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages + \times 1 \times height \times width$. + latent (`None`, `torch.Tensor`): + Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. + The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. + """ + + prediction: Union[np.ndarray, torch.Tensor] + uncertainty: Union[None, np.ndarray, torch.Tensor] + latent: Union[None, torch.Tensor] + + +class MarigoldNormalsPipeline(DiffusionPipeline): + """ + Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the normals latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent + representations. + scheduler (`DDIMScheduler` or `LCMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + prediction_type (`str`, *optional*): + Type of predictions made by the model. + use_full_z_range (`bool`, *optional*): + Whether the normals predicted by this model utilize the full range of the Z dimension, or only its positive + half. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + supported_prediction_types = ("normals",) + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, LCMScheduler], + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + prediction_type: Optional[str] = None, + use_full_z_range: Optional[bool] = True, + default_denoising_steps: Optional[int] = None, + default_processing_resolution: Optional[int] = None, + ): + super().__init__() + + if prediction_type not in self.supported_prediction_types: + logger.warning( + f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " + f"{self.supported_prediction_types}." + ) + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + self.register_to_config( + use_full_z_range=use_full_z_range, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.use_full_z_range = use_full_z_range + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + + self.empty_text_embedding = None + + self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + image: PipelineImageInput, + num_inference_steps: int, + ensemble_size: int, + processing_resolution: int, + resample_method_input: str, + resample_method_output: str, + batch_size: int, + ensembling_kwargs: Optional[Dict[str, Any]], + latents: Optional[torch.Tensor], + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + output_type: str, + output_uncertainty: bool, + ) -> int: + if num_inference_steps is None: + raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") + if num_inference_steps < 1: + raise ValueError("`num_inference_steps` must be positive.") + if ensemble_size < 1: + raise ValueError("`ensemble_size` must be positive.") + if ensemble_size == 2: + logger.warning( + "`ensemble_size` == 2 results are similar to no ensembling (1); " + "consider increasing the value to at least 3." + ) + if ensemble_size == 1 and output_uncertainty: + raise ValueError( + "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " + "greater than 1." + ) + if processing_resolution is None: + raise ValueError( + "`processing_resolution` is not specified and could not be resolved from the model config." + ) + if processing_resolution < 0: + raise ValueError( + "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " + "downsampled processing." + ) + if processing_resolution % self.vae_scale_factor != 0: + raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") + if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_input` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_output` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if batch_size < 1: + raise ValueError("`batch_size` must be positive.") + if output_type not in ["pt", "np"]: + raise ValueError("`output_type` must be one of `pt` or `np`.") + if latents is not None and generator is not None: + raise ValueError("`latents` and `generator` cannot be used together.") + if ensembling_kwargs is not None: + if not isinstance(ensembling_kwargs, dict): + raise ValueError("`ensembling_kwargs` must be a dictionary.") + if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): + raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") + + # image checks + num_images = 0 + W, H = None, None + if not isinstance(image, list): + image = [image] + for i, img in enumerate(image): + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim not in (2, 3, 4): + raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") + H_i, W_i = img.shape[-2:] + N_i = 1 + if img.ndim == 4: + N_i = img.shape[0] + elif isinstance(img, Image.Image): + W_i, H_i = img.size + N_i = 1 + else: + raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") + if W is None: + W, H = W_i, H_i + elif (W, H) != (W_i, H_i): + raise ValueError( + f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" + ) + num_images += N_i + + # latents checks + if latents is not None: + if not torch.is_tensor(latents): + raise ValueError("`latents` must be a torch.Tensor.") + if latents.dim() != 4: + raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") + + if processing_resolution > 0: + max_orig = max(H, W) + new_H = H * processing_resolution // max_orig + new_W = W * processing_resolution // max_orig + if new_H == 0 or new_W == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") + W, H = new_W, new_H + w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor + h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor + shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) + + if latents.shape != shape_expected: + raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") + + # generator checks + if generator is not None: + if isinstance(generator, list): + if len(generator) != num_images * ensemble_size: + raise ValueError( + "The number of generators must match the total number of ensemble members for all input images." + ) + if not all(g.device.type == generator[0].device.type for g in generator): + raise ValueError("`generator` device placement is not consistent in the list.") + elif not isinstance(generator, torch.Generator): + raise ValueError(f"Unsupported generator type: {type(generator)}.") + + return num_images + + def progress_bar(self, iterable=None, total=None, desc=None, leave=True): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + progress_bar_config = dict(**self._progress_bar_config) + progress_bar_config["desc"] = progress_bar_config.get("desc", desc) + progress_bar_config["leave"] = progress_bar_config.get("leave", leave) + if iterable is not None: + return tqdm(iterable, **progress_bar_config) + elif total is not None: + return tqdm(total=total, **progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + num_inference_steps: Optional[int] = None, + ensemble_size: int = 1, + processing_resolution: Optional[int] = None, + match_input_resolution: bool = True, + resample_method_input: str = "bilinear", + resample_method_output: str = "bilinear", + batch_size: int = 1, + ensembling_kwargs: Optional[Dict[str, Any]] = None, + latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: str = "np", + output_uncertainty: bool = False, + output_latent: bool = False, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), + `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For + arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible + by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or + three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the + same width and height. + num_inference_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 + for Marigold-LCM models. + ensemble_size (`int`, defaults to `1`): + Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for + faster inference. + processing_resolution (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, matches the larger input image dimension. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. + match_input_resolution (`bool`, *optional*, defaults to `True`): + When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer + side of the output will equal to `processing_resolution`. + resample_method_input (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize input images to `processing_resolution`. The accepted values are: + `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + resample_method_output (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize output predictions to match the input resolution. The accepted values + are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + batch_size (`int`, *optional*, defaults to `1`): + Batch size; only matters when setting `ensemble_size` or passing a tensor of images. + ensembling_kwargs (`dict`, *optional*, defaults to `None`) + Extra dictionary with arguments for precise ensembling control. The following options are available: + - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in + every pixel location, can be either `"closest"` or `"mean"`. + latents (`torch.Tensor`, *optional*, defaults to `None`): + Latent noise tensors to replace the random initialization. These can be taken from the previous + function call's output. + generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): + Random number generator object to ensure reproducibility. + output_type (`str`, *optional*, defaults to `"np"`): + Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted + values are: `"np"` (numpy array) or `"pt"` (torch tensor). + output_uncertainty (`bool`, *optional*, defaults to `False`): + When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that + the `ensemble_size` argument is set to a value above 2. + output_latent (`bool`, *optional*, defaults to `False`): + When enabled, the output's `latent` field contains the latent codes corresponding to the predictions + within the ensemble. These codes can be saved, modified, and used for subsequent calls with the + `latents` argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a + `tuple` is returned where the first element is the prediction, the second element is the uncertainty + (or `None`), and the third is the latent (or `None`). + """ + + # 0. Resolving variables. + device = self._execution_device + dtype = self.dtype + + # Model-specific optimal default values leading to fast and reasonable results. + if num_inference_steps is None: + num_inference_steps = self.default_denoising_steps + if processing_resolution is None: + processing_resolution = self.default_processing_resolution + + # 1. Check inputs. + num_images = self.check_inputs( + image, + num_inference_steps, + ensemble_size, + processing_resolution, + resample_method_input, + resample_method_output, + batch_size, + ensembling_kwargs, + latents, + generator, + output_type, + output_uncertainty, + ) + + # 2. Prepare empty text conditioning. + # Model invocation: self.tokenizer, self.text_encoder. + if self.empty_text_embedding is None: + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] + + # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, + # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where + # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are + # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` + # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of + # operation and leads to the most reasonable results. Using the native image resolution or any other processing + # resolution can lead to loss of either fine details or global context in the output predictions. + image, padding, original_resolution = self.image_processor.preprocess( + image, processing_resolution, resample_method_input, device, dtype + ) # [N,3,PPH,PPW] + + # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` + # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. + # Latents of each such predictions across all input images and all ensemble members are represented in the + # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded + # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure + # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline + # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken + # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled + # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space + # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. + # Model invocation: self.vae.encoder. + image_latent, pred_latent = self.prepare_latents( + image, latents, generator, ensemble_size, batch_size + ) # [N*E,4,h,w], [N*E,4,h,w] + + del image + + batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat( + batch_size, 1, 1 + ) # [B,1024,2] + + # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. + # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and + # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by + # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded + # model. + # Model invocation: self.unet. + pred_latents = [] + + for i in self.progress_bar( + range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." + ): + batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] + batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w] + effective_batch_size = batch_image_latent.shape[0] + text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): + batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w] + noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] + batch_pred_latent = self.scheduler.step( + noise, t, batch_pred_latent, generator=generator + ).prev_sample # [B,4,h,w] + + pred_latents.append(batch_pred_latent) + + pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] + + del ( + pred_latents, + image_latent, + batch_empty_text_embedding, + batch_image_latent, + batch_pred_latent, + text, + batch_latent, + noise, + ) + + # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, + # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. + # Model invocation: self.vae.decoder. + prediction = torch.cat( + [ + self.decode_prediction(pred_latent[i : i + batch_size]) + for i in range(0, pred_latent.shape[0], batch_size) + ], + dim=0, + ) # [N*E,3,PPH,PPW] + + if not output_latent: + pred_latent = None + + # 7. Remove padding. The output shape is (PH, PW). + prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] + + # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N` + # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape + # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for + # each group independently, it stacks them respectively into batches of `N` almost final predictions and + # uncertainty maps. + uncertainty = None + if ensemble_size > 1: + prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:]) # [N,E,3,PH,PW] + prediction = [ + self.ensemble_normals(prediction[i], output_uncertainty, **(ensembling_kwargs or {})) + for i in range(num_images) + ] # [ [[1,3,PH,PW], [1,1,PH,PW]], ... ] + prediction, uncertainty = zip(*prediction) # [[1,3,PH,PW], ... ], [[1,1,PH,PW], ... ] + prediction = torch.cat(prediction, dim=0) # [N,3,PH,PW] + if output_uncertainty: + uncertainty = torch.cat(uncertainty, dim=0) # [N,1,PH,PW] + else: + uncertainty = None + + # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the + # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. + # After upsampling, the native resolution normal maps are renormalized to unit length to reduce the artifacts. + # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by + # setting the `resample_method_output` parameter (e.g., to `"nearest"`). + if match_input_resolution: + prediction = self.image_processor.resize_antialias( + prediction, original_resolution, resample_method_output, is_aa=False + ) # [N,3,H,W] + prediction = self.normalize_normals(prediction) # [N,3,H,W] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.resize_antialias( + uncertainty, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + + # 10. Prepare the final outputs. + if output_type == "np": + prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N,H,W,1] + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (prediction, uncertainty, pred_latent) + + return MarigoldNormalsOutput( + prediction=prediction, + uncertainty=uncertainty, + latent=pred_latent, + ) + + # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents + def prepare_latents( + self, + image: torch.Tensor, + latents: Optional[torch.Tensor], + generator: Optional[torch.Generator], + ensemble_size: int, + batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + def retrieve_latents(encoder_output): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + image_latent = torch.cat( + [ + retrieve_latents(self.vae.encode(image[i : i + batch_size])) + for i in range(0, image.shape[0], batch_size) + ], + dim=0, + ) # [N,4,h,w] + image_latent = image_latent * self.vae.config.scaling_factor + image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] + + pred_latent = latents + if pred_latent is None: + pred_latent = randn_tensor( + image_latent.shape, + generator=generator, + device=image_latent.device, + dtype=image_latent.dtype, + ) # [N*E,4,h,w] + + return image_latent, pred_latent + + def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: + if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: + raise ValueError( + f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." + ) + + prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] + + prediction = torch.clip(prediction, -1.0, 1.0) + + if not self.use_full_z_range: + prediction[:, 2, :, :] *= 0.5 + prediction[:, 2, :, :] += 0.5 + + prediction = self.normalize_normals(prediction) # [B,3,H,W] + + return prediction # [B,3,H,W] + + @staticmethod + def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + if normals.dim() != 4 or normals.shape[1] != 3: + raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") + + norm = torch.norm(normals, dim=1, keepdim=True) + normals /= norm.clamp(min=eps) + + return normals + + @staticmethod + def ensemble_normals( + normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest" + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is + the number of ensemble members for a given prediction of size `(H x W)`. + + Args: + normals (`torch.Tensor`): + Input ensemble normals maps. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"closest"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and + `"mean"`. + + Returns: + A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of + uncertainties of shape `(1, 1, H, W)`. + """ + if normals.dim() != 4 or normals.shape[1] != 3: + raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") + if reduction not in ("closest", "mean"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + + mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W] + mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] + + sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W] + sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16 + + uncertainty = None + if output_uncertainty: + uncertainty = sim_cos.arccos() # [E,1,H,W] + uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W] + + if reduction == "mean": + return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] + + closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W] + closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W] + closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W] + + return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 04f91d758b94..7ab0a94e5677 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -68,6 +68,7 @@ is_k_diffusion_available, is_k_diffusion_version, is_librosa_available, + is_matplotlib_available, is_note_seq_available, is_notebook, is_onnx_available, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0583cf839ff7..df436bc46c06 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -692,6 +692,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class MarigoldDepthPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class MarigoldNormalsPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class MusicLDMPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b8ce2d7c0466..6f70f5888910 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -295,6 +295,13 @@ except importlib_metadata.PackageNotFoundError: _torchvision_available = False +_matplotlib_available = importlib.util.find_spec("matplotlib") is not None +try: + _matplotlib_version = importlib_metadata.version("matplotlib") + logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}") +except importlib_metadata.PackageNotFoundError: + _matplotlib_available = False + _timm_available = importlib.util.find_spec("timm") is not None if _timm_available: try: @@ -425,6 +432,10 @@ def is_torchvision_available(): return _torchvision_available +def is_matplotlib_available(): + return _matplotlib_available + + def is_safetensors_available(): return _safetensors_available diff --git a/tests/pipelines/marigold/__init__.py b/tests/pipelines/marigold/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/marigold/test_marigold_depth.py b/tests/pipelines/marigold/test_marigold_depth.py new file mode 100644 index 000000000000..24d1981b8fb2 --- /dev/null +++ b/tests/pipelines/marigold/test_marigold_depth.py @@ -0,0 +1,459 @@ +# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +import gc +import random +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + AutoencoderTiny, + LCMScheduler, + MarigoldDepthPipeline, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + load_image, + require_torch_gpu, + slow, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class MarigoldDepthPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = MarigoldDepthPipeline + params = frozenset(["image"]) + batch_params = frozenset(["image"]) + image_params = frozenset(["image"]) + image_latents_params = frozenset(["latents"]) + callback_cfg_params = frozenset([]) + test_xformers_attention = False + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "output_type", + ] + ) + + def get_dummy_components(self, time_cond_proj_dim=None): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + time_cond_proj_dim=time_cond_proj_dim, + sample_size=32, + in_channels=8, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = LCMScheduler( + beta_start=0.00085, + beta_end=0.012, + prediction_type="v_prediction", + set_alpha_to_one=False, + steps_offset=1, + beta_schedule="scaled_linear", + clip_sample=False, + thresholding=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "prediction_type": "depth", + "scale_invariant": True, + "shift_invariant": True, + } + return components + + def get_dummy_tiny_autoencoder(self): + return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4) + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "image": image, + "num_inference_steps": 1, + "processing_resolution": 0, + "generator": generator, + "output_type": "np", + } + return inputs + + def _test_marigold_depth( + self, + generator_seed: int = 0, + expected_slice: np.ndarray = None, + atol: float = 1e-4, + **pipe_kwargs, + ): + device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + pipe_inputs = self.get_dummy_inputs(device, seed=generator_seed) + pipe_inputs.update(**pipe_kwargs) + + prediction = pipe(**pipe_inputs).prediction + + prediction_slice = prediction[0, -3:, -3:, -1].flatten() + + if pipe_inputs.get("match_input_resolution", True): + self.assertEqual(prediction.shape, (1, 32, 32, 1), "Unexpected output resolution") + else: + self.assertTrue(prediction.shape[0] == 1 and prediction.shape[3] == 1, "Unexpected output dimensions") + self.assertEqual( + max(prediction.shape[1:3]), + pipe_inputs.get("processing_resolution", 768), + "Unexpected output resolution", + ) + + self.assertTrue(np.allclose(prediction_slice, expected_slice, atol=atol)) + + def test_marigold_depth_dummy_defaults(self): + self._test_marigold_depth( + expected_slice=np.array([0.4529, 0.5184, 0.4985, 0.4355, 0.4273, 0.4153, 0.5229, 0.4818, 0.4627]), + ) + + def test_marigold_depth_dummy_G0_S1_P32_E1_B1_M1(self): + self._test_marigold_depth( + generator_seed=0, + expected_slice=np.array([0.4529, 0.5184, 0.4985, 0.4355, 0.4273, 0.4153, 0.5229, 0.4818, 0.4627]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P16_E1_B1_M1(self): + self._test_marigold_depth( + generator_seed=0, + expected_slice=np.array([0.4511, 0.4531, 0.4542, 0.5024, 0.4987, 0.4969, 0.5281, 0.5215, 0.5182]), + num_inference_steps=1, + processing_resolution=16, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G2024_S1_P32_E1_B1_M1(self): + self._test_marigold_depth( + generator_seed=2024, + expected_slice=np.array([0.4671, 0.4739, 0.5130, 0.4308, 0.4411, 0.4720, 0.5064, 0.4796, 0.4795]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S2_P32_E1_B1_M1(self): + self._test_marigold_depth( + generator_seed=0, + expected_slice=np.array([0.4165, 0.4485, 0.4647, 0.4003, 0.4577, 0.5074, 0.5106, 0.5077, 0.5042]), + num_inference_steps=2, + processing_resolution=32, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P64_E1_B1_M1(self): + self._test_marigold_depth( + generator_seed=0, + expected_slice=np.array([0.4817, 0.5425, 0.5146, 0.5367, 0.5034, 0.4743, 0.4395, 0.4734, 0.4399]), + num_inference_steps=1, + processing_resolution=64, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P32_E3_B1_M1(self): + self._test_marigold_depth( + generator_seed=0, + expected_slice=np.array([0.3260, 0.3591, 0.2837, 0.2971, 0.2750, 0.2426, 0.4200, 0.3588, 0.3254]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=3, + ensembling_kwargs={"reduction": "mean"}, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P32_E4_B2_M1(self): + self._test_marigold_depth( + generator_seed=0, + expected_slice=np.array([0.3180, 0.4194, 0.3013, 0.2902, 0.3245, 0.2897, 0.4718, 0.4174, 0.3705]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=4, + ensembling_kwargs={"reduction": "mean"}, + batch_size=2, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P16_E1_B1_M0(self): + self._test_marigold_depth( + generator_seed=0, + expected_slice=np.array([0.5515, 0.4588, 0.4197, 0.4741, 0.4229, 0.4328, 0.5333, 0.5314, 0.5182]), + num_inference_steps=1, + processing_resolution=16, + ensemble_size=1, + batch_size=1, + match_input_resolution=False, + ) + + def test_marigold_depth_dummy_no_num_inference_steps(self): + with self.assertRaises(ValueError) as e: + self._test_marigold_depth( + num_inference_steps=None, + expected_slice=np.array([0.0]), + ) + self.assertIn("num_inference_steps", str(e)) + + def test_marigold_depth_dummy_no_processing_resolution(self): + with self.assertRaises(ValueError) as e: + self._test_marigold_depth( + processing_resolution=None, + expected_slice=np.array([0.0]), + ) + self.assertIn("processing_resolution", str(e)) + + +@slow +@require_torch_gpu +class MarigoldDepthPipelineIntegrationTests(unittest.TestCase): + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def _test_marigold_depth( + self, + is_fp16: bool = True, + device: str = "cuda", + generator_seed: int = 0, + expected_slice: np.ndarray = None, + model_id: str = "prs-eth/marigold-lcm-v1-0", + image_url: str = "https://marigoldmonodepth.github.io/images/einstein.jpg", + atol: float = 1e-4, + **pipe_kwargs, + ): + from_pretrained_kwargs = {} + if is_fp16: + from_pretrained_kwargs["variant"] = "fp16" + from_pretrained_kwargs["torch_dtype"] = torch.float16 + + pipe = MarigoldDepthPipeline.from_pretrained(model_id, **from_pretrained_kwargs) + if device == "cuda": + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(generator_seed) + + image = load_image(image_url) + width, height = image.size + + prediction = pipe(image, generator=generator, **pipe_kwargs).prediction + + prediction_slice = prediction[0, -3:, -3:, -1].flatten() + + if pipe_kwargs.get("match_input_resolution", True): + self.assertEqual(prediction.shape, (1, height, width, 1), "Unexpected output resolution") + else: + self.assertTrue(prediction.shape[0] == 1 and prediction.shape[3] == 1, "Unexpected output dimensions") + self.assertEqual( + max(prediction.shape[1:3]), + pipe_kwargs.get("processing_resolution", 768), + "Unexpected output resolution", + ) + + self.assertTrue(np.allclose(prediction_slice, expected_slice, atol=atol)) + + def test_marigold_depth_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self): + self._test_marigold_depth( + is_fp16=False, + device="cpu", + generator_seed=0, + expected_slice=np.array([0.4323, 0.4323, 0.4323, 0.4323, 0.4323, 0.4323, 0.4323, 0.4323, 0.4323]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self): + self._test_marigold_depth( + is_fp16=False, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.1244, 0.1265, 0.1292, 0.1240, 0.1252, 0.1266, 0.1246, 0.1226, 0.1180]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self): + self._test_marigold_depth( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.1241, 0.1262, 0.1290, 0.1238, 0.1250, 0.1265, 0.1244, 0.1225, 0.1179]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self): + self._test_marigold_depth( + is_fp16=True, + device="cuda", + generator_seed=2024, + expected_slice=np.array([0.1710, 0.1725, 0.1738, 0.1700, 0.1700, 0.1696, 0.1698, 0.1663, 0.1592]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self): + self._test_marigold_depth( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]), + num_inference_steps=2, + processing_resolution=768, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self): + self._test_marigold_depth( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.2683, 0.2693, 0.2698, 0.2666, 0.2632, 0.2615, 0.2656, 0.2603, 0.2573]), + num_inference_steps=1, + processing_resolution=512, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self): + self._test_marigold_depth( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.1200, 0.1215, 0.1237, 0.1193, 0.1197, 0.1202, 0.1196, 0.1166, 0.1109]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=3, + ensembling_kwargs={"reduction": "mean"}, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self): + self._test_marigold_depth( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.1121, 0.1135, 0.1155, 0.1111, 0.1115, 0.1118, 0.1111, 0.1079, 0.1019]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=4, + ensembling_kwargs={"reduction": "mean"}, + batch_size=2, + match_input_resolution=True, + ) + + def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self): + self._test_marigold_depth( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.2671, 0.2690, 0.2720, 0.2659, 0.2676, 0.2739, 0.2664, 0.2686, 0.2573]), + num_inference_steps=1, + processing_resolution=512, + ensemble_size=1, + batch_size=1, + match_input_resolution=False, + ) diff --git a/tests/pipelines/marigold/test_marigold_normals.py b/tests/pipelines/marigold/test_marigold_normals.py new file mode 100644 index 000000000000..c86c600be8e5 --- /dev/null +++ b/tests/pipelines/marigold/test_marigold_normals.py @@ -0,0 +1,459 @@ +# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- +import gc +import random +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + AutoencoderTiny, + LCMScheduler, + MarigoldNormalsPipeline, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + load_image, + require_torch_gpu, + slow, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class MarigoldNormalsPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = MarigoldNormalsPipeline + params = frozenset(["image"]) + batch_params = frozenset(["image"]) + image_params = frozenset(["image"]) + image_latents_params = frozenset(["latents"]) + callback_cfg_params = frozenset([]) + test_xformers_attention = False + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "output_type", + ] + ) + + def get_dummy_components(self, time_cond_proj_dim=None): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + time_cond_proj_dim=time_cond_proj_dim, + sample_size=32, + in_channels=8, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + scheduler = LCMScheduler( + beta_start=0.00085, + beta_end=0.012, + prediction_type="v_prediction", + set_alpha_to_one=False, + steps_offset=1, + beta_schedule="scaled_linear", + clip_sample=False, + thresholding=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "prediction_type": "normals", + "use_full_z_range": True, + } + return components + + def get_dummy_tiny_autoencoder(self): + return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4) + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "image": image, + "num_inference_steps": 1, + "processing_resolution": 0, + "generator": generator, + "output_type": "np", + } + return inputs + + def _test_marigold_normals( + self, + generator_seed: int = 0, + expected_slice: np.ndarray = None, + atol: float = 1e-4, + **pipe_kwargs, + ): + device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + pipe_inputs = self.get_dummy_inputs(device, seed=generator_seed) + pipe_inputs.update(**pipe_kwargs) + + prediction = pipe(**pipe_inputs).prediction + + prediction_slice = prediction[0, -3:, -3:, -1].flatten() + + if pipe_inputs.get("match_input_resolution", True): + self.assertEqual(prediction.shape, (1, 32, 32, 3), "Unexpected output resolution") + else: + self.assertTrue(prediction.shape[0] == 1 and prediction.shape[3] == 3, "Unexpected output dimensions") + self.assertEqual( + max(prediction.shape[1:3]), + pipe_inputs.get("processing_resolution", 768), + "Unexpected output resolution", + ) + + self.assertTrue(np.allclose(prediction_slice, expected_slice, atol=atol)) + + def test_marigold_depth_dummy_defaults(self): + self._test_marigold_normals( + expected_slice=np.array([0.0967, 0.5234, 0.1448, -0.3155, -0.2550, -0.5578, 0.6854, 0.5657, -0.1263]), + ) + + def test_marigold_depth_dummy_G0_S1_P32_E1_B1_M1(self): + self._test_marigold_normals( + generator_seed=0, + expected_slice=np.array([0.0967, 0.5234, 0.1448, -0.3155, -0.2550, -0.5578, 0.6854, 0.5657, -0.1263]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P16_E1_B1_M1(self): + self._test_marigold_normals( + generator_seed=0, + expected_slice=np.array([-0.4128, -0.5918, -0.6540, 0.2446, -0.2687, -0.4607, 0.2935, -0.0483, -0.2086]), + num_inference_steps=1, + processing_resolution=16, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G2024_S1_P32_E1_B1_M1(self): + self._test_marigold_normals( + generator_seed=2024, + expected_slice=np.array([0.5731, -0.7631, -0.0199, 0.1609, -0.4628, -0.7044, 0.5761, -0.3471, -0.4498]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S2_P32_E1_B1_M1(self): + self._test_marigold_normals( + generator_seed=0, + expected_slice=np.array([0.1017, -0.6823, -0.2533, 0.1988, 0.3389, 0.8478, 0.7757, 0.5220, 0.8668]), + num_inference_steps=2, + processing_resolution=32, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P64_E1_B1_M1(self): + self._test_marigold_normals( + generator_seed=0, + expected_slice=np.array([-0.2391, 0.7969, 0.6224, 0.0698, 0.5669, -0.2167, -0.1362, -0.8945, -0.5501]), + num_inference_steps=1, + processing_resolution=64, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P32_E3_B1_M1(self): + self._test_marigold_normals( + generator_seed=0, + expected_slice=np.array([0.3826, -0.9634, -0.3835, 0.3514, 0.0691, -0.6182, 0.8709, 0.1590, -0.2181]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=3, + ensembling_kwargs={"reduction": "mean"}, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P32_E4_B2_M1(self): + self._test_marigold_normals( + generator_seed=0, + expected_slice=np.array([0.2500, -0.3928, -0.2415, 0.1133, 0.2357, -0.4223, 0.9967, 0.4859, -0.1282]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=4, + ensembling_kwargs={"reduction": "mean"}, + batch_size=2, + match_input_resolution=True, + ) + + def test_marigold_depth_dummy_G0_S1_P16_E1_B1_M0(self): + self._test_marigold_normals( + generator_seed=0, + expected_slice=np.array([0.9588, 0.3326, -0.0825, -0.0994, -0.3534, -0.4302, 0.3562, 0.4421, -0.2086]), + num_inference_steps=1, + processing_resolution=16, + ensemble_size=1, + batch_size=1, + match_input_resolution=False, + ) + + def test_marigold_depth_dummy_no_num_inference_steps(self): + with self.assertRaises(ValueError) as e: + self._test_marigold_normals( + num_inference_steps=None, + expected_slice=np.array([0.0]), + ) + self.assertIn("num_inference_steps", str(e)) + + def test_marigold_depth_dummy_no_processing_resolution(self): + with self.assertRaises(ValueError) as e: + self._test_marigold_normals( + processing_resolution=None, + expected_slice=np.array([0.0]), + ) + self.assertIn("processing_resolution", str(e)) + + +@slow +@require_torch_gpu +class MarigoldNormalsPipelineIntegrationTests(unittest.TestCase): + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def _test_marigold_normals( + self, + is_fp16: bool = True, + device: str = "cuda", + generator_seed: int = 0, + expected_slice: np.ndarray = None, + model_id: str = "prs-eth/marigold-normals-lcm-v0-1", + image_url: str = "https://marigoldmonodepth.github.io/images/einstein.jpg", + atol: float = 1e-4, + **pipe_kwargs, + ): + from_pretrained_kwargs = {} + if is_fp16: + from_pretrained_kwargs["variant"] = "fp16" + from_pretrained_kwargs["torch_dtype"] = torch.float16 + + pipe = MarigoldNormalsPipeline.from_pretrained(model_id, **from_pretrained_kwargs) + if device == "cuda": + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(generator_seed) + + image = load_image(image_url) + width, height = image.size + + prediction = pipe(image, generator=generator, **pipe_kwargs).prediction + + prediction_slice = prediction[0, -3:, -3:, -1].flatten() + + if pipe_kwargs.get("match_input_resolution", True): + self.assertEqual(prediction.shape, (1, height, width, 3), "Unexpected output resolution") + else: + self.assertTrue(prediction.shape[0] == 1 and prediction.shape[3] == 3, "Unexpected output dimensions") + self.assertEqual( + max(prediction.shape[1:3]), + pipe_kwargs.get("processing_resolution", 768), + "Unexpected output resolution", + ) + + self.assertTrue(np.allclose(prediction_slice, expected_slice, atol=atol)) + + def test_marigold_normals_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self): + self._test_marigold_normals( + is_fp16=False, + device="cpu", + generator_seed=0, + expected_slice=np.array([0.8971, 0.8971, 0.8971, 0.8971, 0.8971, 0.8971, 0.8971, 0.8971, 0.8971]), + num_inference_steps=1, + processing_resolution=32, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_normals_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self): + self._test_marigold_normals( + is_fp16=False, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.7980, 0.7952, 0.7914, 0.7931, 0.7871, 0.7816, 0.7844, 0.7710, 0.7601]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_normals_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self): + self._test_marigold_normals( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.7979, 0.7949, 0.7915, 0.7930, 0.7871, 0.7817, 0.7842, 0.7710, 0.7603]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_normals_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self): + self._test_marigold_normals( + is_fp16=True, + device="cuda", + generator_seed=2024, + expected_slice=np.array([0.8428, 0.8428, 0.8433, 0.8369, 0.8325, 0.8315, 0.8271, 0.8135, 0.8057]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_normals_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self): + self._test_marigold_normals( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.7095, 0.7095, 0.7104, 0.7070, 0.7051, 0.7061, 0.7017, 0.6938, 0.6914]), + num_inference_steps=2, + processing_resolution=768, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_normals_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self): + self._test_marigold_normals( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.7168, 0.7163, 0.7163, 0.7080, 0.7061, 0.7046, 0.7031, 0.7007, 0.6987]), + num_inference_steps=1, + processing_resolution=512, + ensemble_size=1, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_normals_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self): + self._test_marigold_normals( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.7114, 0.7124, 0.7144, 0.7085, 0.7070, 0.7080, 0.7051, 0.6958, 0.6924]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=3, + ensembling_kwargs={"reduction": "mean"}, + batch_size=1, + match_input_resolution=True, + ) + + def test_marigold_normals_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self): + self._test_marigold_normals( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.7412, 0.7441, 0.7490, 0.7383, 0.7388, 0.7437, 0.7329, 0.7271, 0.7300]), + num_inference_steps=1, + processing_resolution=768, + ensemble_size=4, + ensembling_kwargs={"reduction": "mean"}, + batch_size=2, + match_input_resolution=True, + ) + + def test_marigold_normals_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self): + self._test_marigold_normals( + is_fp16=True, + device="cuda", + generator_seed=0, + expected_slice=np.array([0.7188, 0.7144, 0.7134, 0.7178, 0.7207, 0.7222, 0.7231, 0.7041, 0.6987]), + num_inference_steps=1, + processing_resolution=512, + ensemble_size=1, + batch_size=1, + match_input_resolution=False, + )