-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: add activation checkpointing to unet #8554
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
feat: add activation checkpointing to unet #8554
Conversation
WalkthroughAdds optional activation checkpointing to UNet: introduces a private _ActivationCheckpointWrapper that wraps subblocks and uses torch.utils.checkpoint.checkpoint (attempting Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
Status, Documentation and Community
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)
29-33
: Add a brief class docstring to the wrapper.
Improves discoverability and meets docstring guidelines.Apply this diff:
class _ActivationCheckpointWrapper(nn.Module): - def __init__(self, module: nn.Module) -> None: + """Apply activation checkpointing to the wrapped module during training.""" + def __init__(self, module: nn.Module) -> None: super().__init__() self.module = module
134-135
: Document the newuse_checkpointing
arg in the class docstring and user docs.
State trade-offs (memory vs compute), that it’s training-only, incompatible withtorch.no_grad
, and preserves RNG by default.Proposed docstring snippet to add under “Args”:
use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.I can open a docs patch and add a short example enabling the flag.
163-164
: Static wrapping caveat: runtime flips won’t take effect.
After init, changingself.use_checkpointing
won’t rewrap existing blocks. Either document this or add a small helper to (re)build the model if you expect runtime toggling.Do you expect users to toggle this at runtime? If yes, I can sketch a safe rewrap helper.
210-212
: Checkpointing scope is subblock-only; consider an optional broader policy.
Current placement is a good default. If more memory is needed, offer a policy to also wrapdown_path
/up_path
(with a warning about extra compute).I can add a
checkpoint_policy: Literal["subblock","all"] = "subblock"
in__init__
and wire it here on request.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py
(5 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: build-docs
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: packaging
🔇 Additional comments (1)
monai/networks/nets/unet.py (1)
16-21
: Imports for checkpointing look good.
cast
andcheckpoint
are appropriate for the new wrapper.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Fábio S. Ferreira <[email protected]>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
29-43
: Good guard + compatibility fallback.
Training/grad-enabled checks anduse_reentrant=False
withTypeError
fallback are the right call. This addresses the prior review note.
🧹 Nitpick comments (5)
monai/networks/nets/unet.py (5)
29-43
: Avoid per-iteration TypeError cost: detectuse_reentrant
support once.
Resolve support at import/init time to prevent raising an exception every forward on older torch.Apply:
@@ -class _ActivationCheckpointWrapper(nn.Module): +_SUPPORTS_USE_REENTRANT: bool | None = None + +class _ActivationCheckpointWrapper(nn.Module): @@ - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.training and torch.is_grad_enabled() and x.requires_grad: - try: - return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) - except TypeError: - # Fallback for older PyTorch without `use_reentrant` - return cast(torch.Tensor, checkpoint(self.module, x)) - return cast(torch.Tensor, self.module(x)) + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training and torch.is_grad_enabled() and x.requires_grad: + global _SUPPORTS_USE_REENTRANT + if _SUPPORTS_USE_REENTRANT is None: + try: + # probe once + checkpoint(self.module, x, use_reentrant=False) # type: ignore[arg-type] + _SUPPORTS_USE_REENTRANT = True + except TypeError: + _SUPPORTS_USE_REENTRANT = False + except Exception: + # do not change behavior on unexpected errors; fall back below + _SUPPORTS_USE_REENTRANT = False + if _SUPPORTS_USE_REENTRANT: + return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + return cast(torch.Tensor, checkpoint(self.module, x)) + return cast(torch.Tensor, self.module(x))Add outside the hunk (file header):
import inspect # if you switch to signature probing instead of try/exceptNote: PyTorch recommends passing
use_reentrant
explicitly going forward. (docs.pytorch.org)
29-43
: TorchScript: make wrapper script-safe.
try/except
and dynamic checkpoint calls won’t script. Short-circuit under scripting.Apply:
@@ - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + # Avoid checkpoint in scripted graphs + return cast(torch.Tensor, self.module(x))
29-43
: Docstring completeness.
Add Google-style docstrings for the wrapper’s class/init/forward (inputs, returns, raises).Example:
@@ -class _ActivationCheckpointWrapper(nn.Module): - """Apply activation checkpointing to the wrapped module during training.""" +class _ActivationCheckpointWrapper(nn.Module): + """Wrap a module and apply activation checkpointing during training. + + Args: + module: The submodule to checkpoint. + + Returns: + torch.Tensor: Output tensor from the wrapped submodule. + + Raises: + RuntimeError: If checkpoint fails at runtime. + """
90-92
: Tighten theuse_checkpointing
docstring and add a BN caveat.
Keep it on one Args entry and note the BatchNorm limitation.- use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory - at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False. + use_checkpointing: If True, apply activation checkpointing to internal sub-blocks during training to reduce + memory at the cost of extra compute. Bypassed in eval and when gradients are disabled. Note: avoid with + BatchNorm layers due to running-stat updates during recomputation. Defaults to False.Reference on RNG determinism (dropout is handled by default). (docs.pytorch.org)
217-219
: Wrap site is fine; consider BN detection here instead of inside the wrapper.
Keeps policy close to where wrapping occurs and avoids per-instance warnings.- if self.use_checkpointing: - subblock = _ActivationCheckpointWrapper(subblock) + if self.use_checkpointing: + has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in subblock.modules()) + if has_bn: + warnings.warn( + "Skipping activation checkpointing for this subblock (contains BatchNorm).", + RuntimeWarning, + ) + else: + subblock = _ActivationCheckpointWrapper(subblock) return nn.Sequential(down_path, SkipConnection(subblock), up_path)Rationale: avoids double-updating BN stats during recomputation. (github.com)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py
(6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: quick-py3 (ubuntu-latest)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)
16-16
: LGTM: imports are correct and scoped.
cast
andcheckpoint
are needed by the wrapper; no issues.Also applies to: 20-20
141-141
: Public API addition: ensure tests and docs cover new flag.
Add unit tests for parity (on/off), eval bypass, andno_grad()
bypass; document in release notes/configs.Proposed minimal tests:
- Forward/backward equivalence within tolerance for
use_checkpointing={False,True}
on a tiny UNet.model.eval()
andtorch.no_grad()
paths skip checkpoint (nocheckpoint
mocks invoked).- Autocast path under CUDA runs without dtype mismatches.
Do you want a patch with pytest tests added under
tests/test_unet_checkpointing.py
?
170-170
: LGTM: stores flag on the instance.
No concerns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)
29-51
: Warn once for BN, and add param/return docstring per guidelines.Avoid repeated RuntimeWarnings and document the wrapper’s contract.
class _ActivationCheckpointWrapper(nn.Module): - """Apply activation checkpointing to the wrapped module during training.""" + """Apply activation checkpointing to the wrapped module during training. + + Args: + module: submodule to wrap. + Returns: + torch.Tensor: output of the wrapped module. + Warnings: + Skips checkpointing and emits a RuntimeWarning if the submodule contains + BatchNorm to avoid double-updating running stats. + """ def __init__(self, module: nn.Module) -> None: super().__init__() # Pre-detect BatchNorm presence for fast path self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules()) self.module = module + self._bn_warned = False def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training and torch.is_grad_enabled() and x.requires_grad: if self._has_bn: - warnings.warn( - "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating " - "running statistics during recomputation.", - RuntimeWarning, - ) + if not self._bn_warned: + warnings.warn( + "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating " + "running statistics during recomputation.", + RuntimeWarning, + ) + self._bn_warned = True return cast(torch.Tensor, self.module(x)) try: return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) except TypeError: # Fallback for older PyTorch without `use_reentrant` return cast(torch.Tensor, checkpoint(self.module, x)) return cast(torch.Tensor, self.module(x))Minimal tests to add:
- Training vs eval parity (values match with/without checkpointing).
- BN subblock emits RuntimeWarning and bypasses checkpointing.
- Guard under
torch.no_grad()
and whenrequires_grad=False
.
99-101
: Clarifyuse_checkpointing
behavior in docs (BN, grad, training-only, build-time).Make expectations explicit for users.
- use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory - at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False. + use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce + memory at the cost of extra compute. Checkpointing is active only when `self.training` is True, gradients + are enabled, and inputs require gradients; it is bypassed in eval and when grads are disabled. + Sub-blocks containing BatchNorm are not checkpointed to avoid double-updating running stats. + Note: this is a build-time option; changing it after initialization will not retroactively wrap existing + sub-blocks. Defaults to False.
179-180
: Flag is build-time only; consider asserting or documenting.Changing
self.use_checkpointing
post-init has no effect since wrappers are created during construction. The doc update above covers this; alternatively, convert to a read-only attribute to avoid confusion.
226-228
: Wrapping only the subblock is OK; consider optional coverage toggle.If desired later, expose an opt-in to also wrap
down
/up
blocks for additional savings.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py
(6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)
16-21
: Imports look good.Scoped import of
checkpoint
pluscast
is appropriate.
150-151
: API change is sensible and non-breaking.Parameter added at the end; default preserves behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
29-43
: Skip checkpointing for subblocks with BatchNorm to avoid double-updating running stats.
Checkpoint recompute updates BN running stats twice under training. Detect BN in the wrapped module and bypass checkpointing with a warning.Apply this diff:
class _ActivationCheckpointWrapper(nn.Module): - """Apply activation checkpointing to the wrapped module during training.""" + """Apply activation checkpointing to the wrapped module during training. + Skips checkpointing for submodules containing BatchNorm to avoid double-updating + running statistics during recomputation. + """ def __init__(self, module: nn.Module) -> None: super().__init__() self.module = module + self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules()) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.training and torch.is_grad_enabled() and x.requires_grad: + if self.training and torch.is_grad_enabled() and x.requires_grad: + if self._has_bn: + warnings.warn( + "Activation checkpointing skipped for a subblock containing BatchNorm " + "to avoid double-updating running statistics during recomputation.", + RuntimeWarning, + ) + return cast(torch.Tensor, self.module(x)) try: return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) except TypeError: # Fallback for older PyTorch without `use_reentrant` return cast(torch.Tensor, checkpoint(self.module, x)) return cast(torch.Tensor, self.module(x))
🧹 Nitpick comments (3)
monai/networks/nets/unet.py (3)
90-92
: Clarify arg docs and surface BN caveat.
Tighten wording and document BN behavior for transparency.Apply this diff:
- use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory - at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False. + use_checkpointing: If True, applies activation checkpointing to internal sub-blocks during training to reduce + memory at the cost of extra compute. Bypassed in eval mode and when gradients are disabled. + Note: sub-blocks containing BatchNorm are executed without checkpointing to avoid double-updating + running statistics. Defaults to False.
217-219
: Placement of wrapper is sensible; consider optional breadth control.
Future enhancement: expose a knob to checkpoint down/up paths too for deeper memory savings on very deep nets.
141-142
: Add tests to lock behavior.
- Parity: forward/backward equivalence (outputs/grad norms) with vs. without checkpointing.
- Modes: train vs. eval; torch.no_grad().
- Norms: with InstanceNorm and with BatchNorm (assert BN path skips with warning).
I can draft unit tests targeting UNet’s smallest config to keep runtime minimal—want me to open a follow-up?
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py
(6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: build-docs
- GitHub Check: packaging
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.10)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)
16-21
: LGTM: imports for cast/checkpoint are correct.
Direct import of checkpoint and use of typing.cast are appropriate.
35-42
: Validate AMP behavior under fallback (reentrant) checkpointing.
Older Torch (fallback path) may not replay autocast exactly; please verify mixed-precision parity.Minimal check: run a forward/backward with torch.autocast and compare loss/grad norms with and without checkpointing on a small UNet to ensure deltas are within numerical noise.
141-142
: API addition looks good.
Name and default match MONAI conventions.
Description
Introduces an optional
use_checkpointing
flag in theUNet
implementation. When enabled, intermediate activations in the encoder–decoder blocks are recomputed during the backward pass instead of being stored in memory._ActivationCheckpointWrapper
wrapper around sub-blocks.Types of changes
./runtests.sh -f -u --net --coverage
../runtests.sh --quick --unittests --disttests
.make html
command in thedocs/
folder.