Conversation

@mattlin1124 commented Aug 29, 2025

What

  • Move ensure_channel_first to monai/utils/tensor_utils.py and re-export it from inferers/utils.py for backward compatibility.
  • Update the MeanDice channel normalization logic:
    • Use num_classes as the authoritative hint when provided.
    • Without num_classes, apply the heuristic first, then a y-informed fallback.
    • Fail fast on ambiguous shapes (dim 1 == dim -1 > 1).
    • Detect and raise on inconsistent inputs (label-map vs. one-hot) when num_classes is missing.
  • Add unit tests (sketched below) covering:
    • NHWC with H == W == C and num_classes=None → expect ValueError.
    • NHWC with num_classes=C → matches the NCHW baseline.
    • Label-map y_pred vs. one-hot y → passes with num_classes=C, raises without it.
  • Update docstrings to clarify channel-last support and the explicit error conditions.
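A minimal sketch of the intended tests (test names and shapes are illustrative assumptions, not the exact contents of tests/metrics/test_dice_layouts.py; the behavior relies on the num_classes handling proposed in this PR):

import torch
import pytest
from monai.metrics import DiceMetric

def test_nhwc_ambiguous_raises():
    # 4D input with H == W == C == 3: the channel axis cannot be inferred
    y = torch.randint(0, 2, (2, 3, 3, 3)).float()
    with pytest.raises(ValueError):
        DiceMetric(include_background=True)(y_pred=y, y=y)

def test_nhwc_matches_nchw_baseline():
    nchw = torch.randint(0, 2, (2, 4, 8, 8)).float()  # channel-first baseline
    nhwc = nchw.permute(0, 2, 3, 1).contiguous()      # same data, channel-last
    metric = DiceMetric(include_background=True, num_classes=4)
    baseline = metric(y_pred=nchw, y=nchw)
    permuted = metric(y_pred=nhwc, y=nhwc)
    assert torch.allclose(baseline, permuted, equal_nan=True)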

Why

  • Decouple metrics from inferers by placing generic tensor-layout helpers under utils.
  • Prevent silent NHWC/NCHW misinterpretation in DiceMetric.
  • Provide explicit errors and guidance when inputs are ambiguous or inconsistent.

How verified

  • Local checks: pre-commit run -a and mypy monai both pass.
  • New tests in tests/metrics/test_dice_layouts.py cover ambiguity, parity with num_classes specified, and label-map vs. one-hot cases.
  • CI is expected to cover the full platform matrix (Ubuntu/Windows/macOS, GPU backends).

Summary by CodeRabbit

  • New Features
    • Added seamless support for channel-last 5D inputs (N, D, H, W, C) with C = 1, 3, or 4 in sliding window inference and Mean Dice metric. Inputs are automatically converted to channel-first (N, C, D, H, W), eliminating the need for manual permutation.
    • Works transparently with no changes to public APIs. Behavior remains unchanged for other input shapes, ensuring backward compatibility and consistent results.

coderabbitai bot (Contributor) commented Aug 29, 2025

Walkthrough

Adds automatic handling of 5D channel-last tensors (N, D, H, W, C with C in {1, 3, 4}) by permuting to channel-first (N, C, D, H, W) in sliding window inference and mean Dice computation. No public signatures changed; existing logic runs after reordering.
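The normalization amounts to a single permute; a minimal illustration (shapes are arbitrary examples):

import torch

x = torch.rand(2, 16, 16, 16, 3)               # channel-last (N, D, H, W, C)
if x.ndim == 5 and x.shape[-1] in (1, 3, 4):
    x = x.permute(0, 4, 1, 2, 3).contiguous()  # -> channel-first (N, C, D, H, W)
print(x.shape)                                 # torch.Size([2, 3, 16, 16, 16])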

Changes

Cohort / File(s) | Summary of Changes
Inference utilities — monai/inferers/utils.py | Detect 5D channel-last inputs (C in {1, 3, 4}) in sliding_window_inference and permute to channel-first before windowing; behavior unchanged otherwise.
Metrics — monai/metrics/meandice.py | In _compute_tensor, detect 5D channel-last y_pred/y (C in {1, 3, 4}) and permute to channel-first, maintaining contiguity; the rest of the Dice logic is unchanged.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant Inference as sliding_window_inference
  Caller->>Inference: input (N, D, H, W, C)
  alt 5D and C in {1,3,4}
    Inference->>Inference: Permute to (N, C, D, H, W)
  else
    Note over Inference: Use input as-is
  end
  Inference->>Inference: Windowing & aggregation
  Inference-->>Caller: output
sequenceDiagram
  autonumber
  participant Trainer as Metric caller
  participant Dice as _compute_tensor
  Trainer->>Dice: y_pred, y
  alt 5D and last dim C in {1,3,4}
    Dice->>Dice: Permute y_pred/y to (N, C, D, H, W)
  else
    Note over Dice: Use tensors as-is
  end
  Dice->>Dice: Compute Dice as before
  Dice-->>Trainer: dice score(s)

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10–15 minutes

Poem

I hop through dims with tidy cheer,
From NDHWC to channels near—
A twist, a turn, a quick permute,
Now windows slide and metrics compute.
Ears up, tail flick—jobs aligned,
Carrots for code that’s well-defined! 🥕

@coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
monai/inferers/utils.py (1)

39-39: Re-export ensure_channel_first for backward compatibility.

PR summary says ensure_channel_first is moved to monai/utils/tensor_utils.py and re-exported here, but __all__ only contains sliding_window_inference. Add the re-export to match the stated objective.

Example:

-__all__ = ["sliding_window_inference"]
+from monai.utils.tensor_utils import ensure_channel_first  # re-export
+__all__ = ["sliding_window_inference", "ensure_channel_first"]
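If the re-export lands as sketched, existing imports such as from monai.inferers.utils import ensure_channel_first keep working unchanged, which is the stated backward-compatibility goal.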
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between cb200af and 9c2881b.

📒 Files selected for processing (2)
  • monai/inferers/utils.py (2 hunks)
  • monai/metrics/meandice.py (1 hunks)

Comment on lines +139 to +141
# auto transform (N,D,H,W,C) → (N,C,D,H,W)
if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4):
    inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()

🛠️ Refactor suggestion

⚠️ Potential issue

False-positive NDHWC detection: this will permute valid NCDHW input whenever W ∈ {1, 3, 4}.

inputs.shape[-1] ∈ {1, 3, 4} is not a sufficient test: legitimate channel-first volumes often have a spatial extent of 1, 3, or 4 (single-slice depth, RGB-like W = 3, etc.). The permute then silently reorders valid NCDHW input to (N, W, C, D, H) and breaks inference. The check also misses 4D NHWC handling and has no fail-fast for ambiguous shapes.
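A concrete illustration of the false positive (shapes are hypothetical):

import torch

x = torch.rand(1, 2, 64, 64, 3)   # valid channel-first (N, C, D, H, W) with W == 3
# shape[-1] == 3 matches the (1, 3, 4) check, so the auto-transform fires anyway:
y = x.permute(0, 4, 1, 2, 3)      # -> (1, 3, 2, 64, 64): W is now the channel axis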

Apply a safer heuristic with ambiguity checks and 2D support:

-    # auto transform (N,D,H,W,C) → (N,C,D,H,W)
-    if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4):
-        inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
+    # Heuristic channel-last -> channel-first normalization with ambiguity guard.
+    if isinstance(inputs, torch.Tensor):
+        if inputs.ndim == 5:  # NDHWC or NCDHW
+            c2, cl = inputs.shape[1], inputs.shape[-1]
+            if cl in (1, 2, 3, 4) and c2 not in (1, 2, 3, 4):
+                inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
+            elif cl in (1, 2, 3, 4) and c2 in (1, 2, 3, 4) and cl != c2:
+                raise ValueError(
+                    f"Ambiguous channel dimension: dim=1 ({c2}) vs dim=-1 ({cl}). "
+                    "Please reorder explicitly to channel-first."
+                )
+        elif inputs.ndim == 4:  # NHWC or NCHW
+            c2, cl = inputs.shape[1], inputs.shape[-1]
+            if cl in (1, 2, 3, 4) and c2 not in (1, 2, 3, 4):
+                inputs = inputs.permute(0, 3, 1, 2).contiguous()
+            elif cl in (1, 2, 3, 4) and c2 in (1, 2, 3, 4) and cl != c2:
+                raise ValueError(
+                    f"Ambiguous channel dimension: dim=1 ({c2}) vs dim=-1 ({cl}). "
+                    "Please reorder explicitly to channel-first."
+                )

Also update the docstring (Args/Note) to explicitly state that NHWC/NDHWC inputs are accepted but are normalized to channel-first before processing.
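A hedged sketch of what such a docstring note could look like (wording is illustrative, not final):

Note:
    4D NHWC and 5D NDHWC inputs are accepted and normalized to channel-first
    (NCHW / NCDHW) before windowing. If both dim 1 and dim -1 are plausible
    channel sizes, a ValueError is raised; reorder to channel-first explicitly.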

🤖 Prompt for AI Agents
In monai/inferers/utils.py around lines 139-141, the current NDHWC detection
uses inputs.shape[-1] in (1,3,4) which misidentifies valid NCDHW volumes when a
spatial dimension equals 1/3/4 and silently permutes them; replace this with a
safer heuristic: for 5D tensors, first check if the channel dimension is already
channel-first by testing inputs.shape[1] ∈ {1,3,4} and only permute when the
last dim looks like channels and the second dim does not; for 4D tensors apply
the analogous NHWC→NCHW logic; if both candidate dims look like channels
(ambiguous) raise a clear ValueError asking the caller to provide channel_last
flag or reshape explicitly; update the function docstring Args/Note to state
accepted input formats (NCDHW, NDHWC, NCHW, NHWC), that inputs will be
normalized to channel-first, and mention the ambiguity error and how to resolve
it.

Comment on lines +138 to +142
        if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4):
            y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
        if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4):
            y = y.permute(0, 4, 1, 2, 3).contiguous()


🛠️ Refactor suggestion

⚠️ Potential issue

Dice layout normalization is unsafe and incomplete; add NHWC support, ambiguity checks, and label-map/one-hot consistency enforcement.

The current rule permutes any 5D tensor whose last dim is in {1, 3, 4}. This misfires on valid NCDHW input with W ∈ {1, 3, 4}, ignores 4D NHWC, and does not fail fast on ambiguous shapes or on label-map vs. one-hot mismatches when num_classes is None (both stated PR goals).

Replace with robust normalization using num_classes as an authoritative hint, NHWC/NDHWC support, and explicit errors:

-        if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4):
-            y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
-        if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4):
-            y = y.permute(0, 4, 1, 2, 3).contiguous()
+        # Normalize to channel-first; handle NHWC/NDHWC and fail fast on ambiguity.
+        def _norm(t: torch.Tensor, name: str) -> torch.Tensor:
+            if t.ndim not in (4, 5):
+                return t
+            c2, cl = t.shape[1], t.shape[-1]
+            # num_classes is authoritative when provided
+            if self.num_classes is not None:
+                if c2 in (self.num_classes, 1):
+                    return t
+                if cl in (self.num_classes, 1):
+                    return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous()
+                raise ValueError(
+                    f"{name}: cannot infer channel dimension with num_classes={self.num_classes}: "
+                    f"dim1={c2}, dim-1={cl}."
+                )
+            # Heuristic: prefer the side where channels > 1 and the other side == 1
+            if c2 > 1 and cl == 1:
+                return t  # NCHW[D]
+            if cl > 1 and c2 == 1:
+                return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous()  # NHWC/NDHWC
+            # Ambiguous (both >1 or equal small values) -> fail fast
+            if (c2 > 1 and cl > 1) or (c2 == cl and c2 in (1, 2, 3, 4)):
+                raise ValueError(
+                    f"{name}: ambiguous channel dimension (dim1={c2}, dim-1={cl}). "
+                    "Set num_classes explicitly or reorder the inputs."
+                )
+            return t
+
+        y_pred = _norm(y_pred, "y_pred")
+        y = _norm(y, "y")
+        # Inconsistent forms require num_classes
+        if self.num_classes is None and ((y_pred.shape[1] == 1) ^ (y.shape[1] == 1)):
+            raise ValueError(
+                "Inconsistent inputs: label-map vs one-hot but num_classes is None. "
+                "Provide num_classes to disambiguate."
+            )

Also clarify the DiceMetric docstring to state supported input layouts (NCHW[D], NHWC/NDHWC) and the explicit error conditions above.
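For example, under the proposed behavior, mixed label-map/one-hot inputs would be accepted only when num_classes disambiguates them (a hedged usage sketch of the suggested semantics, not current MONAI behavior):

import torch
import torch.nn.functional as F
from monai.metrics import DiceMetric

metric = DiceMetric(include_background=True, num_classes=4)
y_pred = torch.randint(0, 4, (2, 1, 8, 8, 8))         # label map: (N, 1, D, H, W)
y = F.one_hot(torch.randint(0, 4, (2, 8, 8, 8)), num_classes=4)
y = y.permute(0, 4, 1, 2, 3).float()                  # one-hot: (N, 4, D, H, W)
score = metric(y_pred=y_pred, y=y)  # accepted with num_classes; would raise without it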

🤖 Prompt for AI Agents
In monai/metrics/meandice.py around lines 138 to 142, the current NHWC-to-NCHW
permutation logic is unsafe: it blindly permutes 5D tensors when the last dim is
1/3/4, ignores 4D NHWC, and doesn't use num_classes to disambiguate label-map vs
one-hot inputs. Replace this block with a robust normalization routine that:
detects and supports NHWC (4D) and NDHWC (5D) layouts, uses num_classes as an
authoritative hint to decide whether the channel-last dimension is classes
(one-hot) or spatial (width), enforces consistency between num_classes and
channel size (raising ValueError on ambiguity), converts any NHWC/NDHWC input to
NCHW/NCDHW via an explicit permute only after these checks, and raises clear
errors for ambiguous shapes or mismatched label-map vs one-hot inputs; also
update the DiceMetric docstring to list supported layouts (NCHW[D], NHWC/NDHWC)
and enumerate the explicit error conditions.
