Skip to content

Marigold Update: v1-1 models, Intrinsic Image Decomposition pipeline, documentation #10884

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
a490417
minor documentation fixes of the depth and normals pipelines
toshas Feb 17, 2025
d002d5a
update license headers
toshas Feb 17, 2025
40e2040
update model checkpoints in examples
toshas Feb 17, 2025
2f0bbbd
add initial marigold intrinsics pipeline
toshas Feb 19, 2025
d20642c
update uncertainty visualization to work with intrinsics
toshas Feb 20, 2025
e007863
integrate iid
toshas Feb 20, 2025
6b5267f
add marigold intrinsics tests
toshas Feb 22, 2025
caa4a62
update documentation
toshas Feb 23, 2025
5a4196e
Merge branch 'main' into pipeline_marigold_intrinsics
toshas Feb 23, 2025
b83c0a3
Merge branch 'main' into pipeline_marigold_intrinsics
yiyixuxu Feb 24, 2025
dc9cd47
Update docs/source/en/api/pipelines/marigold.md
toshas Feb 24, 2025
d15a951
Update docs/source/en/using-diffusers/marigold_usage.md
toshas Feb 24, 2025
8972a40
Update docs/source/en/using-diffusers/marigold_usage.md
toshas Feb 24, 2025
13204ad
Update docs/source/en/using-diffusers/marigold_usage.md
toshas Feb 24, 2025
fd57911
Update docs/source/en/using-diffusers/marigold_usage.md
toshas Feb 24, 2025
149d4e9
Update docs/source/en/using-diffusers/marigold_usage.md
toshas Feb 24, 2025
2cff146
Update docs/source/en/using-diffusers/marigold_usage.md
toshas Feb 24, 2025
05c60a1
Update docs/source/en/using-diffusers/marigold_usage.md
toshas Feb 24, 2025
8364655
Update docs/source/en/using-diffusers/marigold_usage.md
toshas Feb 24, 2025
d6ff273
Update marigold.md
toshas Feb 24, 2025
2ba8d4d
add torch.compiler.disable to progress_bar to keep it in sync with th…
toshas Feb 24, 2025
d444508
make possible to instantiate the pipeline without vae and unet
toshas Feb 24, 2025
b3d1152
revert to having n_targets as a pipeline property
toshas Feb 24, 2025
3ecbcd7
minor depth ensembling fixes
toshas Feb 25, 2025
9489028
improve api documentation structure
toshas Feb 25, 2025
68be852
attempt at fixing latex jammed into preceding words
toshas Feb 25, 2025
1e97e03
enhance marigold usage section on acceleration
toshas Feb 25, 2025
c5c5e12
improve marigold video processing example
toshas Feb 25, 2025
54f67f5
improve tutorial conclusion
toshas Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 89 additions & 34 deletions docs/source/en/api/pipelines/marigold.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/en/api/pipelines/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Latte](latte) | text2image |
| [LEDITS++](ledits_pp) | image editing |
| [Lumina-T2X](lumina) | text2image |
| [Marigold](marigold) | depth |
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
| [MultiDiffusion](panorama) | text2image |
| [MusicLDM](musicldm) | text2audio |
| [PAG](pag) | text2image |
Expand Down
485 changes: 312 additions & 173 deletions docs/source/en/using-diffusers/marigold_usage.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@
"Lumina2Text2ImgPipeline",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
Expand Down Expand Up @@ -845,6 +846,7 @@
Lumina2Text2ImgPipeline,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
"MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
]
)
Expand Down Expand Up @@ -603,6 +604,7 @@
from .lumina2 import Lumina2Text2ImgPipeline
from .marigold import (
MarigoldDepthPipeline,
MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
)
from .mochi import MochiPipeline
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/marigold/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
else:
_import_structure["marigold_image_processing"] = ["MarigoldImageProcessor"]
_import_structure["pipeline_marigold_depth"] = ["MarigoldDepthOutput", "MarigoldDepthPipeline"]
_import_structure["pipeline_marigold_intrinsics"] = ["MarigoldIntrinsicsOutput", "MarigoldIntrinsicsPipeline"]
_import_structure["pipeline_marigold_normals"] = ["MarigoldNormalsOutput", "MarigoldNormalsPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Expand All @@ -35,6 +36,7 @@
else:
from .marigold_image_processing import MarigoldImageProcessor
from .pipeline_marigold_depth import MarigoldDepthOutput, MarigoldDepthPipeline
from .pipeline_marigold_intrinsics import MarigoldIntrinsicsOutput, MarigoldIntrinsicsPipeline
from .pipeline_marigold_normals import MarigoldNormalsOutput, MarigoldNormalsPipeline

else:
Expand Down
141 changes: 127 additions & 14 deletions src/diffusers/pipelines/marigold/marigold_image_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
from typing import List, Optional, Tuple, Union
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information and citation instructions are available on the
# Marigold project website: https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
Expand Down Expand Up @@ -379,7 +397,7 @@ def visualize_depth(
val_min: float = 0.0,
val_max: float = 1.0,
color_map: str = "Spectral",
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
) -> List[PIL.Image.Image]:
"""
Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`.

Expand All @@ -391,7 +409,7 @@ def visualize_depth(
color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel
depth prediction into colored representation.

Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization.
Returns: `List[PIL.Image.Image]` with depth maps visualization.
"""
if val_max <= val_min:
raise ValueError(f"Invalid values range: [{val_min}, {val_max}].")
Expand Down Expand Up @@ -436,7 +454,7 @@ def export_depth_to_16bit_png(
depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
val_min: float = 0.0,
val_max: float = 1.0,
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
) -> List[PIL.Image.Image]:
def export_depth_to_16bit_png_one(img, idx=None):
prefix = "Depth" + (f"[{idx}]" if idx else "")
if not isinstance(img, np.ndarray) and not torch.is_tensor(img):
Expand Down Expand Up @@ -478,7 +496,7 @@ def visualize_normals(
flip_x: bool = False,
flip_y: bool = False,
flip_z: bool = False,
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
) -> List[PIL.Image.Image]:
"""
Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`.

Expand All @@ -492,7 +510,7 @@ def visualize_normals(
flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference.
Default direction is facing the observer.

Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization.
Returns: `List[PIL.Image.Image]` with surface normals visualization.
"""
flip_vec = None
if any((flip_x, flip_y, flip_z)):
Expand Down Expand Up @@ -528,6 +546,99 @@ def visualize_normals_one(img, idx=None):
else:
raise ValueError(f"Unexpected input type: {type(normals)}")

@staticmethod
def visualize_intrinsics(
prediction: Union[
np.ndarray,
torch.Tensor,
List[np.ndarray],
List[torch.Tensor],
],
target_properties: Dict[str, Any],
color_map: Union[str, Dict[str, str]] = "binary",
) -> List[Dict[str, PIL.Image.Image]]:
"""
Visualizes intrinsic image decomposition, such as predictions of the `MarigoldIntrinsicsPipeline`.

Args:
prediction (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
Intrinsic image decomposition.
target_properties (`Dict[str, Any]`):
Decomposition properties. Expected entries: `target_names: List[str]` and a dictionary with keys
`prediction_space: str`, `sub_target_names: List[Union[str, Null]]` (must have 3 entries, null for
missing modalities), `up_to_scale: bool`, one for each target and sub-target.
color_map (`Union[str, Dict[str, str]]`, *optional*, defaults to `"Spectral"`):
Color map used to convert a single-channel predictions into colored representations. When a dictionary
is passed, each modality can be colored with its own color map.

Returns: `List[Dict[str, PIL.Image.Image]]` with intrinsic image decomposition visualization.
"""
if "target_names" not in target_properties:
raise ValueError("Missing `target_names` in target_properties")
if not isinstance(color_map, str) and not (
isinstance(color_map, dict)
and all(isinstance(k, str) and isinstance(v, str) for k, v in color_map.items())
):
raise ValueError("`color_map` must be a string or a dictionary of strings")
n_targets = len(target_properties["target_names"])

def visualize_targets_one(images, idx=None):
# img: [T, 3, H, W]
out = {}
for target_name, img in zip(target_properties["target_names"], images):
img = img.permute(1, 2, 0) # [H, W, 3]
prediction_space = target_properties[target_name].get("prediction_space", "srgb")
if prediction_space == "stack":
sub_target_names = target_properties[target_name]["sub_target_names"]
if len(sub_target_names) != 3 or any(
not (isinstance(s, str) or s is None) for s in sub_target_names
):
raise ValueError(f"Unexpected target sub-names {sub_target_names} in {target_name}")
for i, sub_target_name in enumerate(sub_target_names):
if sub_target_name is None:
continue
sub_img = img[:, :, i]
sub_prediction_space = target_properties[sub_target_name].get("prediction_space", "srgb")
if sub_prediction_space == "linear":
sub_up_to_scale = target_properties[sub_target_name].get("up_to_scale", False)
if sub_up_to_scale:
sub_img = sub_img / max(sub_img.max().item(), 1e-6)
sub_img = sub_img ** (1 / 2.2)
cmap_name = (
color_map if isinstance(color_map, str) else color_map.get(sub_target_name, "binary")
)
sub_img = MarigoldImageProcessor.colormap(sub_img, cmap=cmap_name, bytes=True)
sub_img = PIL.Image.fromarray(sub_img.cpu().numpy())
out[sub_target_name] = sub_img
elif prediction_space == "linear":
up_to_scale = target_properties[target_name].get("up_to_scale", False)
if up_to_scale:
img = img / max(img.max().item(), 1e-6)
img = img ** (1 / 2.2)
elif prediction_space == "srgb":
pass
img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
img = PIL.Image.fromarray(img)
out[target_name] = img
return out

if prediction is None or isinstance(prediction, list) and any(o is None for o in prediction):
raise ValueError("Input prediction is `None`")
if isinstance(prediction, (np.ndarray, torch.Tensor)):
prediction = MarigoldImageProcessor.expand_tensor_or_array(prediction)
if isinstance(prediction, np.ndarray):
prediction = MarigoldImageProcessor.numpy_to_pt(prediction) # [N*T,3,H,W]
if not (prediction.ndim == 4 and prediction.shape[1] == 3 and prediction.shape[0] % n_targets == 0):
raise ValueError(f"Unexpected input shape={prediction.shape}, expecting [N*T,3,H,W].")
N_T, _, H, W = prediction.shape
N = N_T // n_targets
prediction = prediction.reshape(N, n_targets, 3, H, W)
return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)]
elif isinstance(prediction, list):
return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)]
else:
raise ValueError(f"Unexpected input type: {type(prediction)}")

@staticmethod
def visualize_uncertainty(
uncertainty: Union[
Expand All @@ -537,24 +648,26 @@ def visualize_uncertainty(
List[torch.Tensor],
],
saturation_percentile=95,
) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
) -> List[PIL.Image.Image]:
"""
Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`.
Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline`, `MarigoldNormalsPipeline`, or
`MarigoldIntrinsicsPipeline`.

Args:
uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
Uncertainty maps.
saturation_percentile (`int`, *optional*, defaults to `95`):
Specifies the percentile uncertainty value visualized with maximum intensity.

Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization.
Returns: `List[PIL.Image.Image]` with uncertainty visualization.
"""

def visualize_uncertainty_one(img, idx=None):
prefix = "Uncertainty" + (f"[{idx}]" if idx else "")
if img.min() < 0:
raise ValueError(f"{prefix}: unexected data range, min={img.min()}.")
img = img.squeeze(0).cpu().numpy()
raise ValueError(f"{prefix}: unexpected data range, min={img.min()}.")
img = img.permute(1, 2, 0) # [H,W,C]
img = img.squeeze(2).cpu().numpy() # [H,W] or [H,W,3]
saturation_value = np.percentile(img, saturation_percentile)
img = np.clip(img * 255 / saturation_value, 0, 255)
img = img.astype(np.uint8)
Expand All @@ -566,9 +679,9 @@ def visualize_uncertainty_one(img, idx=None):
if isinstance(uncertainty, (np.ndarray, torch.Tensor)):
uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty)
if isinstance(uncertainty, np.ndarray):
uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,1,H,W]
if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1):
raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].")
uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,C,H,W]
if not (uncertainty.ndim == 4 and uncertainty.shape[1] in (1, 3)):
raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,C,H,W] with C in (1,3).")
return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
elif isinstance(uncertainty, list):
return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
Expand Down
34 changes: 19 additions & 15 deletions src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
# Copyright 2024 The HuggingFace Team. All rights reserved.
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@
# limitations under the License.
# --------------------------------------------------------------------------
# More information and citation instructions are available on the
# Marigold project website: https://marigoldmonodepth.github.io
# Marigold project website: https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
from dataclasses import dataclass
from functools import partial
Expand Down Expand Up @@ -64,7 +64,7 @@
>>> import torch

>>> pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
... "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
... "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
... ).to("cuda")

>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
Expand All @@ -86,11 +86,12 @@ class MarigoldDepthOutput(BaseOutput):

Args:
prediction (`np.ndarray`, `torch.Tensor`):
Predicted depth maps with values in the range [0, 1]. The shape is always $numimages \times 1 \times height
\times width$, regardless of whether the images were passed as a 4D array or a list.
Predicted depth maps with values in the range [0, 1]. The shape is $numimages \times 1 \times height \times
width$ for `torch.Tensor` or $numimages \times height \times width \times 1$ for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
\times 1 \times height \times width$.
\times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
Expand Down Expand Up @@ -208,6 +209,11 @@ def check_inputs(
output_type: str,
output_uncertainty: bool,
) -> int:
actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
if actual_vae_scale_factor != self.vae_scale_factor:
raise ValueError(
f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})."
)
if num_inference_steps is None:
raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
if num_inference_steps < 1:
Expand Down Expand Up @@ -320,6 +326,7 @@ def check_inputs(

return num_images

@torch.compiler.disable
def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
Expand Down Expand Up @@ -370,11 +377,9 @@ def __call__(
same width and height.
num_inference_steps (`int`, *optional*, defaults to `None`):
Number of denoising diffusion steps during inference. The default value `None` results in automatic
selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
for Marigold-LCM models.
selection.
ensemble_size (`int`, defaults to `1`):
Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
faster inference.
Number of ensemble predictions. Higher values result in measurable improvements and visual degradation.
processing_resolution (`int`, *optional*, defaults to `None`):
Effective processing resolution. When set to `0`, matches the larger input image dimension. This
produces crisper predictions, but may also lead to the overall loss of global context. The default
Expand Down Expand Up @@ -486,9 +491,7 @@ def __call__(
# `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
# into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
# reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
# code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
# as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
# noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
# code. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
# dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
# Model invocation: self.vae.encoder.
image_latent, pred_latent = self.prepare_latents(
Expand Down Expand Up @@ -733,6 +736,7 @@ def init_param(depth: torch.Tensor):
param = init_s.cpu().numpy()
else:
raise ValueError("Unrecognized alignment.")
param = param.astype(np.float64)

return param

Expand Down Expand Up @@ -775,7 +779,7 @@ def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:

if regularizer_strength > 0:
prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
err_near = (0.0 - prediction.min()).abs().item()
err_near = prediction.min().abs().item()
err_far = (1.0 - prediction.max()).abs().item()
cost += (err_near + err_far) * regularizer_strength

Expand Down
Loading