BUGFIX: support NDHWC input in sliding_window_inference and DiceMetric #8550
base: dev
Conversation
Walkthrough

Adds automatic handling of 5D channel-last tensors (N, D, H, W, C with C in {1, 3, 4}) by permuting to channel-first (N, C, D, H, W) in sliding window inference and mean Dice computation. No public signatures changed; existing logic runs after reordering.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Inference as sliding_window_inference
    Caller->>Inference: input (N, D, H, W, C)
    alt 5D and C in {1,3,4}
        Inference->>Inference: Permute to (N, C, D, H, W)
    else
        Note over Inference: Use input as-is
    end
    Inference->>Inference: Windowing & aggregation
    Inference-->>Caller: output
```

```mermaid
sequenceDiagram
    autonumber
    participant Trainer as Metric caller
    participant Dice as _compute_tensor
    Trainer->>Dice: y_pred, y
    alt 5D and last dim C in {1,3,4}
        Dice->>Dice: Permute y_pred/y to (N, C, D, H, W)
    else
        Note over Dice: Use tensors as-is
    end
    Dice->>Dice: Compute Dice as before
    Dice-->>Trainer: dice score(s)
```
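The permute step shown in both diagrams can be sketched in plain PyTorch. `to_channel_first_5d` is a hypothetical helper name for illustration, not part of the MONAI API:

```python
import torch


def to_channel_first_5d(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the PR's heuristic: a 5D tensor whose last dimension is
    # 1, 3, or 4 is assumed to be channel-last (N, D, H, W, C) and is
    # permuted to channel-first (N, C, D, H, W); anything else passes through.
    if isinstance(x, torch.Tensor) and x.ndim == 5 and x.shape[-1] in (1, 3, 4):
        return x.permute(0, 4, 1, 2, 3).contiguous()
    return x


x = torch.zeros(2, 16, 32, 32, 3)           # N, D, H, W, C
print(tuple(to_channel_first_5d(x).shape))  # (2, 3, 16, 32, 32)
```

A tensor that is already channel-first with a larger last spatial dimension is left untouched, which is exactly the property the review comments below probe.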
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10–15 minutes
Actionable comments posted: 2
🧹 Nitpick comments (1)
monai/inferers/utils.py (1)

Line 39: Re-export `ensure_channel_first` for backward compatibility. The PR summary says `ensure_channel_first` is moved to monai/utils/tensor_utils.py and re-exported here, but `__all__` only contains `sliding_window_inference`. Add the re-export to match the stated objective.
Example:
```diff
-__all__ = ["sliding_window_inference"]
+from monai.utils.tensor_utils import ensure_channel_first  # re-export
+__all__ = ["sliding_window_inference", "ensure_channel_first"]
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
- monai/inferers/utils.py (2 hunks)
- monai/metrics/meandice.py (1 hunk)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (windows-latest)
```python
# auto transform (N,D,H,W,C) → (N,C,D,H,W)
if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4):
    inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
```
🛠️ Refactor suggestion
False-positive NDHWC detection: will permute valid NCDHW when W in {1,2,3,4}.
inputs.shape[-1] ∈ {1,3,4} is not sufficient; legitimate channel-first volumes often have a spatial size of 1/3/4 (single-slice, RGB-like W=3, etc.). The check then silently permutes a valid NCDHW volume into the wrong layout and breaks inference. It also lacks NHWC (4D) handling and an ambiguous-shape fail-fast.
Apply a safer heuristic with ambiguity checks and 2D support:

```diff
-    # auto transform (N,D,H,W,C) → (N,C,D,H,W)
-    if isinstance(inputs, torch.Tensor) and inputs.ndim == 5 and inputs.shape[-1] in (1, 3, 4):
-        inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
+    # Heuristic channel-last -> channel-first normalization with ambiguity guard.
+    if isinstance(inputs, torch.Tensor):
+        if inputs.ndim == 5:  # NDHWC or NCDHW
+            c2, cl = inputs.shape[1], inputs.shape[-1]
+            if cl in (1, 2, 3, 4) and c2 not in (1, 2, 3, 4):
+                inputs = inputs.permute(0, 4, 1, 2, 3).contiguous()
+            elif cl in (1, 2, 3, 4) and c2 in (1, 2, 3, 4) and cl != c2:
+                raise ValueError(
+                    f"Ambiguous channel dimension: dim=1 ({c2}) vs dim=-1 ({cl}). "
+                    "Please reorder explicitly to channel-first."
+                )
+        elif inputs.ndim == 4:  # NHWC or NCHW
+            c2, cl = inputs.shape[1], inputs.shape[-1]
+            if cl in (1, 2, 3, 4) and c2 not in (1, 2, 3, 4):
+                inputs = inputs.permute(0, 3, 1, 2).contiguous()
+            elif cl in (1, 2, 3, 4) and c2 in (1, 2, 3, 4) and cl != c2:
+                raise ValueError(
+                    f"Ambiguous channel dimension: dim=1 ({c2}) vs dim=-1 ({cl}). "
+                    "Please reorder explicitly to channel-first."
+                )
```
Also update the docstring (Args/Note) to explicitly state that NHWC/NDHWC inputs are accepted but are normalized to channel-first before processing.
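The false positive is easy to reproduce standalone. Below, `naive_fix` mimics the PR's current rule and `guarded_fix` sketches the suggested heuristic; both names are illustrative, not MONAI functions:

```python
import torch


def naive_fix(x: torch.Tensor) -> torch.Tensor:
    # The PR's original rule: permute any 5D tensor whose last dim is 1/3/4.
    if x.ndim == 5 and x.shape[-1] in (1, 3, 4):
        return x.permute(0, 4, 1, 2, 3)
    return x


def guarded_fix(x: torch.Tensor) -> torch.Tensor:
    # Sketch of the suggested guard: only permute when dim 1 does not also
    # look like a channel dimension; raise when both candidates do.
    if x.ndim == 5:
        c2, cl = x.shape[1], x.shape[-1]
        if cl in (1, 2, 3, 4) and c2 not in (1, 2, 3, 4):
            return x.permute(0, 4, 1, 2, 3)
        if cl in (1, 2, 3, 4) and c2 in (1, 2, 3, 4) and cl != c2:
            raise ValueError(f"Ambiguous channel dimension: {c2} vs {cl}")
    return x


ncdhw = torch.zeros(1, 3, 64, 64, 4)     # valid channel-first, spatial W == 4
print(tuple(naive_fix(ncdhw).shape))     # (1, 4, 3, 64, 64) -- silently broken
try:
    guarded_fix(ncdhw)                   # fails fast instead of corrupting
except ValueError as e:
    print("guarded:", e)

ndhwc = torch.zeros(1, 64, 64, 32, 3)    # genuinely channel-last
print(tuple(guarded_fix(ndhwc).shape))   # (1, 3, 64, 64, 32)
```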
```python
if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4):
    y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4):
    y = y.permute(0, 4, 1, 2, 3).contiguous()
```
🛠️ Refactor suggestion
Dice layout normalization is unsafe and incomplete; add NHWC support, ambiguity checks, and label-map/one-hot consistency enforcement.
Current rule permutes whenever last dim ∈ {1,3,4} for 5D only. This misfires for valid NCDHW with W ∈ {1,3,4}, ignores 4D NHWC, and doesn’t fail fast on ambiguous shapes or label-map vs one-hot mismatches when num_classes is None (per PR goals).
Replace with robust normalization using num_classes as an authoritative hint, NHWC/NDHWC support, and explicit errors:

```diff
-        if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 5 and y_pred.shape[-1] in (1, 3, 4):
-            y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous()
-        if isinstance(y, torch.Tensor) and y.ndim == 5 and y.shape[-1] in (1, 3, 4):
-            y = y.permute(0, 4, 1, 2, 3).contiguous()
+        # Normalize to channel-first; handle NHWC/NDHWC and fail fast on ambiguity.
+        def _norm(t: torch.Tensor, name: str) -> torch.Tensor:
+            if t.ndim not in (4, 5):
+                return t
+            c2, cl = t.shape[1], t.shape[-1]
+            # num_classes is authoritative when provided
+            if self.num_classes is not None:
+                if c2 in (self.num_classes, 1):
+                    return t
+                if cl in (self.num_classes, 1):
+                    return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous()
+                raise ValueError(
+                    f"{name}: cannot infer channel dimension with num_classes={self.num_classes}: "
+                    f"dim1={c2}, dim-1={cl}."
+                )
+            # Heuristic: prefer the side where channels > 1 and the other side == 1
+            if c2 > 1 and cl == 1:
+                return t  # NCHW[D]
+            if cl > 1 and c2 == 1:
+                return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous()  # NHWC/NDHWC
+            # Ambiguous (both >1 or equal small values) -> fail fast
+            if (c2 > 1 and cl > 1) or (c2 == cl and c2 in (1, 2, 3, 4)):
+                raise ValueError(
+                    f"{name}: ambiguous channel dimension (dim1={c2}, dim-1={cl}). "
+                    "Set num_classes explicitly or reorder the inputs."
+                )
+            return t
+
+        y_pred = _norm(y_pred, "y_pred")
+        y = _norm(y, "y")
+        # Inconsistent forms require num_classes
+        if self.num_classes is None and ((y_pred.shape[1] == 1) ^ (y.shape[1] == 1)):
+            raise ValueError(
+                "Inconsistent inputs: label-map vs one-hot but num_classes is None. "
+                "Provide num_classes to disambiguate."
+            )
```
Also clarify the DiceMetric docstring to state supported input layouts (NCHW[D], NHWC/NDHWC) and the explicit error conditions above.
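The num_classes-authoritative branch of the suggestion can be exercised on its own; `norm_with_num_classes` below is an illustrative free-function version of the proposed `_norm`, not an existing MONAI API:

```python
import torch


def norm_with_num_classes(t: torch.Tensor, num_classes: int) -> torch.Tensor:
    # When num_classes is known it is authoritative: dim 1 matching it
    # (or equal to 1, for label maps) means the tensor is already
    # channel-first; a matching last dim triggers a channel-last ->
    # channel-first permute; any other shape is rejected.
    if t.ndim not in (4, 5):
        return t
    c2, cl = t.shape[1], t.shape[-1]
    if c2 in (num_classes, 1):
        return t
    if cl in (num_classes, 1):
        return t.permute(0, t.ndim - 1, *range(1, t.ndim - 1)).contiguous()
    raise ValueError(f"cannot infer channel dim: dim1={c2}, dim-1={cl}")


nhwc = torch.zeros(2, 64, 64, 3)                     # 4D channel-last, 3 classes
print(tuple(norm_with_num_classes(nhwc, 3).shape))   # (2, 3, 64, 64)
ndhwc = torch.zeros(2, 16, 64, 64, 3)                # 5D channel-last
print(tuple(norm_with_num_classes(ndhwc, 3).shape))  # (2, 3, 16, 64, 64)
```

Because the same permute expression handles both ranks, the 4D and 5D paths stay in lockstep instead of duplicating index lists.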
What
- Move `ensure_channel_first` to monai/utils/tensor_utils.py and re-export from inferers/utils.py for backward compatibility.
- Update MeanDice channel normalization logic:
  - Use num_classes as authoritative hint when provided.
  - Without num_classes: apply heuristic first, then y-informed fallback.
  - Fail fast on ambiguous shapes (dim1 == dim-1 > 1).
  - Detect and raise on inconsistent inputs (label-map vs one-hot) when num_classes is missing.
- Add unit tests covering:
  - NHWC with H == W == C and num_classes=None → expect ValueError.
  - NHWC with num_classes=C → matches NCHW baseline.
  - Label-map y_pred vs one-hot y → passes with num_classes=C, raises without.
- Update docstrings to clarify channel-last support and explicit error conditions.
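The label-map vs one-hot case in the test list can be sketched with plain tensors; `check_label_form` below is a standalone stand-in for the guard this PR describes, not current DiceMetric behavior:

```python
import torch


def check_label_form(y_pred: torch.Tensor, y: torch.Tensor, num_classes=None) -> None:
    # Mirrors the described guard: a label map (channel size 1) paired with
    # a one-hot tensor (channel size > 1) is rejected unless num_classes is
    # given to disambiguate the two forms.
    if num_classes is None and ((y_pred.shape[1] == 1) ^ (y.shape[1] == 1)):
        raise ValueError("label-map vs one-hot mismatch; provide num_classes")


y_pred = torch.zeros(2, 1, 8, 8)    # label map
y = torch.zeros(2, 3, 8, 8)         # one-hot, 3 classes
try:
    check_label_form(y_pred, y)                 # no hint -> error
except ValueError as e:
    print("raised:", e)
check_label_form(y_pred, y, num_classes=3)      # explicit hint -> accepted
```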
Why
- Decouple metrics from inferers by placing generic tensor-layout helpers under utils.
- Prevent silent NHWC/NCHW misinterpretation in DiceMetric.
- Provide explicit errors and guidance when inputs are ambiguous or inconsistent.

How verified
- Local checks: `pre-commit run -a` and `mypy monai` passed.
- New tests in tests/metrics/test_dice_layouts.py cover ambiguity, parity with specified num_classes, and label-map vs one-hot cases.
- CI expected to cover the full platform matrix (Ubuntu/Windows/macOS, GPU backends).