
Conversation


@ferreirafabio80 commented Sep 3, 2025

Description

Introduces an optional use_checkpointing flag in the UNet implementation. When enabled, intermediate activations in the encoder–decoder blocks are recomputed during the backward pass instead of being stored in memory.

  • Implemented via a lightweight _ActivationCheckpointWrapper module around sub-blocks.
  • Checkpointing is applied only during training to avoid overhead at inference (see the usage sketch below).
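
A minimal usage sketch (network hyperparameters are illustrative; use_checkpointing is the flag added by this PR):

import torch
from monai.networks.nets import UNet

# Small 3D UNet with activation checkpointing enabled via the new flag.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
    use_checkpointing=True,
)

model.train()  # checkpointing is applied only in training mode
x = torch.randn(1, 1, 64, 64, 64)
loss = model(x).sum()
loss.backward()  # sub-block activations are recomputed here rather than kept in memory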

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.


coderabbitai bot commented Sep 3, 2025

Walkthrough

Adds optional activation checkpointing to UNet: introduces a private _ActivationCheckpointWrapper that wraps sub-blocks and calls torch.utils.checkpoint.checkpoint (attempting use_reentrant=False and falling back to the legacy call on TypeError). UNet.__init__ gains use_checkpointing: bool = False and stores it. When enabled, connection sub-blocks are wrapped with the checkpoint wrapper. The wrapper checkpoints only in training mode with gradients being computed; otherwise it executes the module normally. Documentation was updated to describe the new parameter and behavior.
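
For reference, the wrapper's forward logic as quoted in the review diffs later in this thread (a condensed sketch, not the full file):

from typing import cast

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class _ActivationCheckpointWrapper(nn.Module):
    """Apply activation checkpointing to the wrapped module during training."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Checkpoint only when recomputation can pay off: training with gradients flowing.
        if self.training and torch.is_grad_enabled() and x.requires_grad:
            try:
                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
            except TypeError:
                # Fallback for older PyTorch without `use_reentrant`
                return cast(torch.Tensor, checkpoint(self.module, x))
        return cast(torch.Tensor, self.module(x))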

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes


@ferreirafabio80 changed the title from "Feat/add activation checkpointing to unet" to "feat: add activation checkpointing to unet" on Sep 3, 2025

@coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)

29-33: Add a brief class docstring to the wrapper.
Improves discoverability and meets docstring guidelines.

Apply this diff:

 class _ActivationCheckpointWrapper(nn.Module):
-    def __init__(self, module: nn.Module) -> None:
+    """Apply activation checkpointing to the wrapped module during training."""
+    def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module

134-135: Document the new use_checkpointing arg in the class docstring and user docs.
State trade-offs (memory vs compute), that it’s training-only, incompatible with torch.no_grad, and preserves RNG by default.

Proposed docstring snippet to add under “Args”:

use_checkpointing: if True, apply activation checkpointing to internal sub-blocks
    during training to reduce memory at the cost of extra compute. Checkpointing
    is bypassed in eval and when gradients are disabled. Defaults to False.

I can open a docs patch and add a short example enabling the flag.


163-164: Static wrapping caveat: runtime flips won’t take effect.
After init, changing self.use_checkpointing won’t rewrap existing blocks. Either document this or add a small helper to (re)build the model if you expect runtime toggling.

Do you expect users to toggle this at runtime? If yes, I can sketch a safe rewrap helper.


210-212: Checkpointing scope is subblock-only; consider an optional broader policy.
Current placement is a good default. If more memory is needed, offer a policy to also wrap down_path/up_path (with a warning about extra compute).

I can add a checkpoint_policy: Literal["subblock","all"] = "subblock" in __init__ and wire it here on request.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 0968da2 and 66edcb5.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/nets/unet.py
🔇 Additional comments (1)
monai/networks/nets/unet.py (1)

16-21: Imports for checkpointing look good.
cast and checkpoint are appropriate for the new wrapper.

ferreirafabio80 and others added 3 commits September 4, 2025 15:36
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Fábio S. Ferreira <[email protected]>

@coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)

29-43: Good guard + compatibility fallback.
Training/grad-enabled checks and use_reentrant=False with TypeError fallback are the right call. This addresses the prior review note.

🧹 Nitpick comments (5)
monai/networks/nets/unet.py (5)

29-43: Avoid per-iteration TypeError cost: detect use_reentrant support once.
Resolve support at import/init time to prevent raising an exception every forward on older torch.

Apply:

@@
-class _ActivationCheckpointWrapper(nn.Module):
+_SUPPORTS_USE_REENTRANT: bool | None = None
+
+class _ActivationCheckpointWrapper(nn.Module):
@@
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.training and torch.is_grad_enabled() and x.requires_grad:
-            try:
-                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
-            except TypeError:
-                # Fallback for older PyTorch without `use_reentrant`
-                return cast(torch.Tensor, checkpoint(self.module, x))
-        return cast(torch.Tensor, self.module(x))
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            global _SUPPORTS_USE_REENTRANT
+            if _SUPPORTS_USE_REENTRANT is None:
+                try:
+                    # probe once
+                    checkpoint(self.module, x, use_reentrant=False)  # type: ignore[arg-type]
+                    _SUPPORTS_USE_REENTRANT = True
+                except TypeError:
+                    _SUPPORTS_USE_REENTRANT = False
+                except Exception:
+                    # do not change behavior on unexpected errors; fall back below
+                    _SUPPORTS_USE_REENTRANT = False
+            if _SUPPORTS_USE_REENTRANT:
+                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+            return cast(torch.Tensor, checkpoint(self.module, x))
+        return cast(torch.Tensor, self.module(x))

Add outside the hunk (file header):

import inspect  # if you switch to signature probing instead of try/except

Note: PyTorch recommends passing use_reentrant explicitly going forward. (docs.pytorch.org)
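
For illustration, a one-time signature probe along those lines (assuming inspect.signature can see the installed checkpoint's parameters; the try/except probe above is the more defensive option, and the _checkpoint helper name is hypothetical):

import inspect

from torch.utils.checkpoint import checkpoint

# Resolved once at import time: does this torch version accept `use_reentrant`?
_SUPPORTS_USE_REENTRANT = "use_reentrant" in inspect.signature(checkpoint).parameters


def _checkpoint(module, x):
    if _SUPPORTS_USE_REENTRANT:
        return checkpoint(module, x, use_reentrant=False)
    return checkpoint(module, x)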


29-43: TorchScript: make wrapper script-safe.
try/except and dynamic checkpoint calls won’t script. Short-circuit under scripting.

Apply:

@@
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            # Avoid checkpoint in scripted graphs
+            return cast(torch.Tensor, self.module(x))

29-43: Docstring completeness.
Add Google-style docstrings for the wrapper’s class/init/forward (inputs, returns, raises).

Example:

@@
-class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+class _ActivationCheckpointWrapper(nn.Module):
+    """Wrap a module and apply activation checkpointing during training.
+
+    Args:
+        module: The submodule to checkpoint.
+
+    Returns:
+        torch.Tensor: Output tensor from the wrapped submodule.
+
+    Raises:
+        RuntimeError: If checkpoint fails at runtime.
+    """

90-92: Tighten the use_checkpointing docstring and add a BN caveat.
Keep it on one Args entry and note the BatchNorm limitation.

-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-        at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: If True, apply activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Bypassed in eval and when gradients are disabled. Note: avoid with
+            BatchNorm layers due to running-stat updates during recomputation. Defaults to False.

Reference on RNG determinism (dropout is handled by default). (docs.pytorch.org)


217-219: Wrap site is fine; consider BN detection here instead of inside the wrapper.
Keeps policy close to where wrapping occurs and avoids per-instance warnings.

-        if self.use_checkpointing:
-            subblock = _ActivationCheckpointWrapper(subblock)
+        if self.use_checkpointing:
+            has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in subblock.modules())
+            if has_bn:
+                warnings.warn(
+                    "Skipping activation checkpointing for this subblock (contains BatchNorm).",
+                    RuntimeWarning,
+                )
+            else:
+                subblock = _ActivationCheckpointWrapper(subblock)
         return nn.Sequential(down_path, SkipConnection(subblock), up_path)

Rationale: avoids double-updating BN stats during recomputation. (github.com)

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between e66e357 and e112457.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (6 hunks)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)

16-16: LGTM: imports are correct and scoped.
cast and checkpoint are needed by the wrapper; no issues.

Also applies to: 20-20


141-141: Public API addition: ensure tests and docs cover new flag.
Add unit tests for parity (on/off), eval bypass, and no_grad() bypass; document in release notes/configs.

Proposed minimal tests:

  • Forward/backward equivalence within tolerance for use_checkpointing={False,True} on a tiny UNet.
  • model.eval() and torch.no_grad() paths skip checkpoint (no checkpoint mocks invoked).
  • Autocast path under CUDA runs without dtype mismatches.

Do you want a patch with pytest tests added under tests/test_unet_checkpointing.py?
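
One possible shape for the parity test (a sketch under the assumptions of this PR; both models are re-seeded rather than sharing a state_dict, since wrapping may change parameter key names):

# tests/test_unet_checkpointing.py (file name as suggested above)
import torch

from monai.networks.nets import UNet


def _build_unet(use_checkpointing: bool) -> UNet:
    torch.manual_seed(0)  # identical initialisation for both variants
    return UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16),
        strides=(2, 2),
        num_res_units=1,
        use_checkpointing=use_checkpointing,
    )


def test_checkpointing_parity():
    base, ckpt = _build_unet(False), _build_unet(True)
    base.train()
    ckpt.train()
    x = torch.randn(2, 1, 32, 32)
    out_base, out_ckpt = base(x), ckpt(x)
    assert torch.allclose(out_base, out_ckpt, atol=1e-6)
    out_base.sum().backward()
    out_ckpt.sum().backward()
    for p_base, p_ckpt in zip(base.parameters(), ckpt.parameters()):
        assert torch.allclose(p_base.grad, p_ckpt.grad, atol=1e-5)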


170-170: LGTM: stores flag on the instance.
No concerns.


@coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)

29-51: Warn once for BN, and add param/return docstring per guidelines.

Avoid repeated RuntimeWarnings and document the wrapper’s contract.

 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """Apply activation checkpointing to the wrapped module during training.
+
+    Args:
+        module: submodule to wrap.
+    Returns:
+        torch.Tensor: output of the wrapped module.
+    Warnings:
+        Skips checkpointing and emits a RuntimeWarning if the submodule contains
+        BatchNorm to avoid double-updating running stats.
+    """
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         # Pre-detect BatchNorm presence for fast path
         self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
         self.module = module
+        self._bn_warned = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
             if self._has_bn:
-                warnings.warn(
-                    "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
-                    "running statistics during recomputation.",
-                    RuntimeWarning,
-                )
+                if not self._bn_warned:
+                    warnings.warn(
+                        "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
+                        "running statistics during recomputation.",
+                        RuntimeWarning,
+                    )
+                    self._bn_warned = True
                 return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
                 # Fallback for older PyTorch without `use_reentrant`
                 return cast(torch.Tensor, checkpoint(self.module, x))
         return cast(torch.Tensor, self.module(x))

Minimal tests to add:

  • Training vs eval parity (values match with/without checkpointing).
  • BN subblock emits RuntimeWarning and bypasses checkpointing.
  • Guard under torch.no_grad() and when requires_grad=False.

99-101: Clarify use_checkpointing behavior in docs (BN, grad, training-only, build-time).

Make expectations explicit for users.

-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-        at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Checkpointing is active only when `self.training` is True, gradients
+            are enabled, and inputs require gradients; it is bypassed in eval and when grads are disabled.
+            Sub-blocks containing BatchNorm are not checkpointed to avoid double-updating running stats.
+            Note: this is a build-time option; changing it after initialization will not retroactively wrap existing
+            sub-blocks. Defaults to False.

179-180: Flag is build-time only; consider asserting or documenting.

Changing self.use_checkpointing post-init has no effect since wrappers are created during construction. The doc update above covers this; alternatively, convert to a read-only attribute to avoid confusion.
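
If the read-only route is preferred, a minimal sketch of the pattern (the stand-in class and _use_checkpointing name are hypothetical):

import torch.nn as nn


class CheckpointedNet(nn.Module):  # stand-in for UNet, for illustration only
    def __init__(self, use_checkpointing: bool = False) -> None:
        super().__init__()
        self._use_checkpointing = use_checkpointing  # fixed at construction time

    @property
    def use_checkpointing(self) -> bool:
        """Read-only view of the build-time flag; assigning to it raises AttributeError."""
        return self._use_checkpointing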


226-228: Wrapping only the subblock is OK; consider optional coverage toggle.

If desired later, expose an opt-in to also wrap down/up blocks for additional savings.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between e112457 and f673ca1.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (6 hunks)
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)

16-21: Imports look good.

Scoped import of checkpoint plus cast is appropriate.


150-151: API change is sensible and non-breaking.

Parameter added at the end; default preserves behavior.


@coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)

29-43: Skip checkpointing for subblocks with BatchNorm to avoid double-updating running stats.
Checkpoint recompute updates BN running stats twice under training. Detect BN in the wrapped module and bypass checkpointing with a warning.

Apply this diff:

 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """Apply activation checkpointing to the wrapped module during training.

+    Skips checkpointing for submodules containing BatchNorm to avoid double-updating
+    running statistics during recomputation.
+    """
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
+        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.training and torch.is_grad_enabled() and x.requires_grad:
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            if self._has_bn:
+                warnings.warn(
+                    "Activation checkpointing skipped for a subblock containing BatchNorm "
+                    "to avoid double-updating running statistics during recomputation.",
+                    RuntimeWarning,
+                )
+                return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
                 # Fallback for older PyTorch without `use_reentrant`
                 return cast(torch.Tensor, checkpoint(self.module, x))
         return cast(torch.Tensor, self.module(x))
🧹 Nitpick comments (3)
monai/networks/nets/unet.py (3)

90-92: Clarify arg docs and surface BN caveat.
Tighten wording and document BN behavior for transparency.

Apply this diff:

-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-        at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: If True, applies activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Bypassed in eval mode and when gradients are disabled.
+            Note: sub-blocks containing BatchNorm are executed without checkpointing to avoid double-updating
+            running statistics. Defaults to False.

217-219: Placement of wrapper is sensible; consider optional breadth control.
Future enhancement: expose a knob to checkpoint down/up paths too for deeper memory savings on very deep nets.


141-142: Add tests to lock behavior.

  • Parity: forward/backward equivalence (outputs/grad norms) with vs. without checkpointing.
  • Modes: train vs. eval; torch.no_grad().
  • Norms: with InstanceNorm and with BatchNorm (assert BN path skips with warning).

I can draft unit tests targeting UNet’s smallest config to keep runtime minimal—want me to open a follow-up?

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between f673ca1 and 69540ff.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (6 hunks)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)

16-21: LGTM: imports for cast/checkpoint are correct.
Direct import of checkpoint and use of typing.cast are appropriate.


35-42: Validate AMP behavior under fallback (reentrant) checkpointing.
Older Torch (fallback path) may not replay autocast exactly; please verify mixed-precision parity.

Minimal check: run a forward/backward with torch.autocast and compare loss/grad norms with and without checkpointing on a small UNet to ensure deltas are within numerical noise.
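
A rough way to run that check (CUDA assumed; the helper name and sizes are illustrative):

import torch

from monai.networks.nets import UNet


def amp_loss_and_grad_norm(use_checkpointing: bool) -> tuple[float, float]:
    torch.manual_seed(0)
    model = UNet(
        spatial_dims=2, in_channels=1, out_channels=1,
        channels=(4, 8, 16), strides=(2, 2),
        use_checkpointing=use_checkpointing,
    ).cuda().train()
    x = torch.randn(2, 1, 32, 32, device="cuda")
    with torch.autocast("cuda"):
        loss = model(x).float().mean()
    loss.backward()
    grads = torch.cat([p.grad.reshape(-1) for p in model.parameters()])
    return loss.item(), torch.linalg.vector_norm(grads).item()


# Deltas between the two settings should stay within numerical noise.
baseline = amp_loss_and_grad_norm(use_checkpointing=False)
checkpointed = amp_loss_and_grad_norm(use_checkpointing=True)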


141-142: API addition looks good.
Name and default match MONAI conventions.
